refactor(memory): restructure memory agent and config management

- Reorganize imports and remove unused dependencies across memory agent controllers
- Extract config validation logic into dedicated validators module
- Create new memory_config_model and memory_config_schema for configuration management
- Implement memory_config_service for centralized config handling
- Add embedder_utils module for embedding model utilities
- Refactor memory agent service to use new config validation framework
- Clean up configuration files (remove config.json, testdata.json, dbrun.json)
- Remove deprecated hybrid_chatbot.py and config overrides
- Update logging configuration and error handling across memory modules
- Consolidate LLM and embedding model validation into validators
- Improve code organization and reduce duplication in memory storage services
- Enhance type classification and verification tools with better error handling
This commit is contained in:
Ke Sun
2025-12-21 20:32:41 +08:00
parent 7386ea32f1
commit 1e3ba39150
53 changed files with 3122 additions and 3407 deletions

View File

@@ -1,36 +1,28 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from typing import List, Optional
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.core.logging_config import get_api_logger
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
@@ -39,95 +31,6 @@ router = APIRouter(
)
def _validate_model_reference(db, model_cls, model_id, config_id, label, error_label, id_field):
    """Ensure ``model_id`` refers to an existing, active row of ``model_cls``.

    Shared by the LLM and embedding checks of ``validate_config_id`` so both
    follow the same rules and emit the same message shapes.

    Args:
        db: Database session used for the lookup.
        model_cls: Mapped model class queried by its ``id`` column.
        model_id: Identifier stored on the configuration row (may be falsy).
        config_id: Owning configuration ID, used only in messages.
        label: Model kind used in "not found"/"not active" messages.
        error_label: Model kind used in the generic failure message.
        id_field: Name of the configuration attribute holding ``model_id``.

    Raises:
        ValueError: If ``model_id`` is unset, the row is missing or inactive,
            or the lookup itself fails (chained to the original exception).
    """
    if not model_id:
        api_logger.error(f"Config {config_id} has no {id_field} set")
        raise ValueError(f"Config {config_id} has no {id_field} set")

    try:
        model = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        if not model.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model.name}")
    except ValueError:
        # Validation failures raised above are already logged; pass them through.
        raise
    except Exception as e:
        error_msg = f"Error validating {error_label} model: {e}"
        api_logger.error(error_msg, exc_info=True)
        # Chain the cause so the underlying failure stays visible in tracebacks.
        raise ValueError(error_msg) from e


def validate_config_id(config_id: Optional[int], db: Session) -> int:
    """Validate that ``config_id`` is present and usable.

    When ``config_id`` is ``None`` the ``config_id`` environment variable is
    used as a fallback. The configuration row must exist, and both its
    ``llm_id`` and ``embedding_id`` must reference existing, active models.

    Args:
        config_id: Configuration ID to validate, or ``None`` to use the
            environment fallback.
        db: Database session used for the existence checks.

    Returns:
        The validated ``config_id``.

    Raises:
        ValueError: If no ID is available, the configuration does not exist,
            a referenced model is missing/inactive/unset, or a database error
            occurs (chained to the original exception).
    """
    if config_id is None:
        api_logger.info("config_id is required but was not provided")
        # NOTE(review): os.getenv returns a str, so on this path the value
        # queried and returned is a string rather than an int - confirm that
        # callers and the DataConfig.config_id column tolerate this.
        config_id = os.getenv('config_id')
        if config_id is None:
            raise ValueError("config_id is required but was not provided")

    try:
        # Imported inside the try so an import failure is reported as a
        # ValueError by the generic handler below, like any other failure here.
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        _validate_model_reference(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", error_label="LLM", id_field="llm_id",
        )
        _validate_model_reference(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", error_label="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Validation errors are already logged with a specific message.
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {e}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
@@ -225,12 +128,7 @@ async def write_server(
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -270,8 +168,14 @@ async def write_server(
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@@ -292,12 +196,7 @@ async def write_server_async(
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -352,12 +251,7 @@ async def read_server(
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
@@ -390,8 +284,14 @@ async def read_server(
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@@ -456,12 +356,7 @@ async def read_server_async(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")

View File

@@ -1,45 +1,45 @@
from typing import Optional, Union
import os
import uuid
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, UploadFile
from fastapi.responses import StreamingResponse
from typing import Optional
from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_storage_service import (
MemoryStorageService,
DataConfigService,
kb_type_distribution,
search_dialogue,
search_chunk,
search_statement,
search_entity,
search_all,
search_detials,
search_edges,
search_entity_graph,
MemoryStorageService,
analytics_hot_memory_tags,
analytics_memory_insight_report,
analytics_recent_activity_stats,
analytics_user_summary,
kb_type_distribution,
search_all,
search_chunk,
search_detials,
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
ConfigPilotRun,
)
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
# Get API logger
api_logger = get_api_logger()
@@ -329,8 +329,10 @@ async def pilot_run(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StreamingResponse:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}"
)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),
@@ -338,8 +340,8 @@ async def pilot_run(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
"X-Accel-Buffering": "no",
},
)
"""
@@ -528,8 +530,8 @@ async def get_user_summary_api(
except Exception as e:
api_logger.error(f"User summary failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
from app.core.memory.utils.self_reflexion_utils import self_reflexion
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""

View File

@@ -326,7 +326,7 @@ def log_prompt_rendering(prompt_type: str, content: str) -> None:
logger.info(log_message)
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
def log_template_rendering(template_name: str, context: Optional[dict] = None) -> None:
"""Log template rendering information.
Logs the template name and context keys for debugging template rendering.
@@ -575,6 +575,43 @@ def get_named_logger(name: str) -> logging.Logger:
return get_agent_logger(name)
def get_config_logger() -> logging.Logger:
    """Return the ``memory.config`` logger used for memory configuration work.

    Memory logging is set up on first use via
    ``LoggingConfig.setup_memory_logging()`` when
    ``LoggingConfig._memory_loggers_initialized`` is false. This function
    attaches no handlers, formatters or level of its own to the returned
    logger; output therefore depends on what its ancestor loggers provide.

    Returns:
        The ``logging.Logger`` registered under the name ``memory.config``.

    Example:
        >>> logger = get_config_logger()
        >>> logger.info("Loading configuration", extra={
        ...     "config_id": 123,
        ...     "workspace_id": "uuid-here",
        ...     "operation": "load_config"
        ... })
    """
    memory_logging_ready = LoggingConfig._memory_loggers_initialized
    if not memory_logging_ready:
        # Lazily initialise memory logging before handing out the logger.
        LoggingConfig.setup_memory_logging()

    return logging.getLogger("memory.config")
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
"""Get a standard logger for memory module components.

View File

@@ -9,11 +9,11 @@ import logging
import re
import uuid
from datetime import datetime
from typing import Dict, Any
from langchain_core.messages import AIMessage
from typing import Any, Dict
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
@@ -25,7 +25,8 @@ async def create_input_message(
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor
multimodal_processor: MultimodalProcessor,
memory_config: MemoryConfig,
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
@@ -46,6 +47,7 @@ async def create_input_message(
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
memory_config: MemoryConfig object containing all configuration
Returns:
State update with AIMessage containing tool_call
@@ -53,7 +55,7 @@ async def create_input_message(
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
@@ -123,20 +125,24 @@ async def create_input_message(
f"with ID: {tool_call_id}"
)
# Build tool arguments
tool_args = {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
}
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id
},
"args": tool_args,
"id": tool_call_id
}]
)

View File

@@ -9,14 +9,14 @@ import logging
import time
from typing import Any, Callable, Dict
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_content_payload,
extract_tool_call_id,
extract_content_payload
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
@@ -38,8 +38,9 @@ class ToolExecutionNode:
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
memory_config: MemoryConfig object containing all configuration
"""
def __init__(
self,
tool: Callable,
@@ -49,8 +50,9 @@ class ToolExecutionNode:
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type:str,
user_rag_memory_id:str
storage_type: str,
user_rag_memory_id: str,
memory_config: MemoryConfig,
):
"""
Initialize the tool execution node.
@@ -63,6 +65,9 @@ class ToolExecutionNode:
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
storage_type: Storage type for the workspace
user_rag_memory_id: User RAG memory identifier
memory_config: MemoryConfig object containing all configuration
"""
self.tool_node = ToolNode([tool])
self.id = node_id
@@ -72,9 +77,10 @@ class ToolExecutionNode:
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type=storage_type
self.user_rag_memory_id=user_rag_memory_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
self.memory_config = memory_config
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
@@ -124,8 +130,12 @@ class ToolExecutionNode:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
)
# Log raw message content for debugging
if hasattr(last_message, 'content'):
raw = last_message.content
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
except Exception as e:
logger.error(
@@ -143,8 +153,9 @@ class ToolExecutionNode:
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
memory_config=self.memory_config,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id
user_rag_memory_id=self.user_rag_memory_id,
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
@@ -179,7 +190,29 @@ class ToolExecutionNode:
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Return the result directly - it already contains the messages list
# Check for error in tool response
error_entry = None
if result and "messages" in result:
for msg in result["messages"]:
if hasattr(msg, 'content'):
try:
import json
content = msg.content
if isinstance(content, str):
parsed = json.loads(content)
if isinstance(parsed, dict) and "error" in parsed:
error_msg = parsed["error"]
logger.warning(
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
)
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
except (json.JSONDecodeError, TypeError):
pass
# Return result with error tracking if error was found
if error_entry:
result["errors"] = [error_entry]
return result
except Exception as e:
@@ -187,13 +220,15 @@ class ToolExecutionNode:
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Return error as ToolMessage to maintain message chain consistency
# Track error in state and return error message
from langchain_core.messages import ToolMessage
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
]
],
"errors": [error_entry]
}

View File

@@ -1,38 +1,26 @@
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Literal
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from functools import partial
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
from langgraph.checkpoint.memory import InMemorySaver
from app.core.memory.agent.utils.redis_tool import store
from app.core.logging_config import get_agent_logger
# Import new modular components
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue
from app.core.memory.agent.langgraph_graph.nodes import (
ToolExecutionNode,
create_input_message,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
logger = get_agent_logger(__name__)
@@ -44,9 +32,9 @@ redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# 在工作流中添加循环计数更新
# Update loop count in workflow
async def update_loop_count(state):
    """Return a state update that increments ``loop_count`` by one.

    A state without a ``loop_count`` entry is treated as having a count of 0,
    so the first update yields 1. Only the ``loop_count`` key is returned.
    """
    return {"loop_count": state.get("loop_count", 0) + 1}
@@ -54,13 +42,13 @@ async def update_loop_count(state):
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# 添加边界检查
# Add boundary check
if not messages:
return END
counter.add(1) # 累加 1
counter.add(1) # Increment by 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
logger.debug(f"[should_continue] Current loop count: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
@@ -71,15 +59,15 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # 最大循环次数 3
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# 添加默认返回值,避免返回 None
# Add default return value to avoid returning None
counter.reset()
return "Summary" # 或根据业务需求选择合适的默认值
return "Summary" # Default based on business requirements
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
@@ -115,8 +103,8 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1':
return 'Retrieve_Summary'
# 添加默认返回值,避免返回 None
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
# Add default return value to avoid returning None
return 'Retrieve_Summary' # Default based on business logic
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
@@ -151,46 +139,7 @@ def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
# 在 input_sentence 函数中修改参数名称
async def input_sentence(state, name, id, search_switch, apply_id, group_id):
    """Build the initial tool-call message for tool ``name`` from the latest message.

    The content of the last message in ``state["messages"]`` (or ``""`` when
    there are none) becomes the ``sentence`` argument. Paths ending in
    ``.jpg``/``.png`` are first passed to ``picture_model_requests``; paths
    ending in one of ``audio_extensions`` are passed to ``Vico_recognition``.
    If the text contains ``verified_data``, the value of its ``"query"`` field
    is extracted and used instead.

    Args:
        state: Graph state holding a ``"messages"`` list.
        name: Name of the tool to call.
        id: Session identifier; also used as the prefix of the tool-call id.
            (Shadows the builtin, kept for interface compatibility.)
        search_switch: Search mode switch forwarded to the tool.
        apply_id: Application identifier forwarded to the tool.
        group_id: Group identifier forwarded to the tool.

    Returns:
        A state update ``{"messages": [AIMessage(...)]}`` whose single tool
        call carries the arguments described above.
    """
    messages = state["messages"]
    last_message = messages[-1].content if messages else ""

    # Extension checks only make sense for string content; anything else
    # is passed through untouched instead of failing on ``.endswith``.
    if isinstance(last_message, str) and last_message.endswith(('.jpg', '.png')):
        last_message = await picture_model_requests(last_message)
    if isinstance(last_message, str) and any(last_message.endswith(ext) for ext in audio_extensions):
        last_message = await Vico_recognition([last_message]).run()
        logger.debug(f"Audio recognition result: {last_message}")

    uuid_str = uuid.uuid4()

    if 'verified_data' in str(last_message):
        # Strip escaped newlines/backslashes so the "query" field can be matched.
        messages_last = str(last_message).replace('\\n', '').replace('\\', '')
        query_match = re.search(r'"query": "(.*?)",', messages_last)
        if query_match:
            last_message = query_match.group(1)
        else:
            # Keep the original content rather than failing when the marker is
            # present but no "query" field can be found.
            logger.warning("verified_data found but no query field could be extracted")

    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": name,
                    "args": {
                        "sentence": last_message,
                        'sessionid': id,
                        'messages_id': str(uuid_str),
                        "search_switch": search_switch,  # search_switch travels inside args
                        "apply_id": apply_id,
                        "group_id": group_id
                    },
                    "id": id + f'_{uuid_str}'
                }]
            )
        ]
    }
return 'Split_The_Problem' # Default case
class ProblemExtensionNode:
@@ -208,30 +157,28 @@ class ProblemExtensionNode:
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
if self.tool_name=='Input_Summary':
tool_call =re.findall("'id': '(.*?)'",str(last_message))[0]
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# try:
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
# except:
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
if self.tool_name == 'Input_Summary':
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
else:
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# Try to extract actual content payload from previous tool result
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# 捕获 ToolMessage content 字段(支持单/双引号),并避免贪婪匹配
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# 回退:直接尝试使用原始字符串
# Fallback: use raw string directly
extracted_payload = raw_msg
# 优先尝试将内容解析为 JSON
# Try to parse content as JSON first
try:
content = json.loads(extracted_payload)
except Exception:
# 尝试从文本中提取 JSON 片段再解析
# Try to extract JSON fragment from text and parse
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
@@ -240,14 +187,14 @@ class ProblemExtensionNode:
break
except Exception:
continue
# 如果仍然失败,则以原始字符串作为内容
# If still fails, use raw string as content
content = parsed if parsed is not None else extracted_payload
# 根据工具名称构建正确的参数
# Build correct parameters based on tool name
tool_args = {}
if self.tool_name == "Verify":
# Verify工具需要contextusermessages参数
# Verify tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -256,7 +203,7 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve工具需要contextusermessages参数
# Retrieve tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -266,9 +213,9 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary工具需要字符串类型的context参数
# Summary tool requires string type context parameter
if isinstance(content, dict):
# 将字典转换为JSON字符串
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
@@ -276,24 +223,24 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary工具需要字符串类型的context参数
# Summary_fails tool requires string type context parameter
if isinstance(content, dict):
# 将字典转换为JSON字符串
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name=='Input_Summary':
tool_args["context"] =str(last_message)
elif self.tool_name == 'Input_Summary':
tool_args["context"] = str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name=='Retrieve_Summary' :
elif self.tool_name == 'Retrieve_Summary':
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
@@ -320,7 +267,7 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# 其他工具使用context参数
# Other tools use context parameter
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -349,12 +296,24 @@ class ProblemExtensionNode:
@asynccontextmanager
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
"""
Create a read graph workflow for memory operations.
Args:
namespace: Namespace identifier
tools: MCP tools loaded from session
search_switch: Search mode switch ("0", "1", or "2")
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type (optional)
user_rag_memory_id: User RAG memory ID (optional)
"""
memory = InMemorySaver()
tool=[i.name for i in tools ]
tool = [i.name for i in tools]
logger.info(f"Initializing read graph with tools: {tool}")
if config_id:
logger.info(f"使用配置 ID: {config_id}")
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
@@ -382,9 +341,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
@@ -394,9 +354,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
@@ -406,7 +367,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_node = ToolExecutionNode(
@@ -418,9 +380,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
@@ -430,9 +393,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
@@ -442,9 +406,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
@@ -454,16 +419,16 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
return await create_input_message(
state=state,
tool_name=tool_name,
@@ -471,7 +436,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor
multimodal_processor=multimodal_processor,
memory_config=memory_config,
)
@@ -501,8 +467,3 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
graph = workflow.compile(checkpointer=memory)
yield graph
# 添加到文件末尾或创建新的执行脚本
# 在 memory_agent_service.py 文件中添加以下函数

View File

@@ -128,6 +128,15 @@ def extract_content_payload(message: Any) -> Any:
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(raw_content, list):
for block in raw_content:
if isinstance(block, dict) and block.get('type') == 'text':
raw_content = block.get('text', '')
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
break
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
@@ -140,13 +149,16 @@ def extract_content_payload(message: Any) -> Any:
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
return json.loads(raw_content)
parsed = json.loads(raw_content)
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
pass
@@ -156,9 +168,12 @@ def extract_content_payload(message: Any) -> Any:
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
return json.loads(candidate)
parsed = json.loads(candidate)
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
return raw_content

View File

@@ -1,69 +1,71 @@
import asyncio
import json
from contextlib import asynccontextmanager
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.utils.llm_tools import WriteState
import warnings
import sys
from langchain_core.messages import AIMessage
import warnings
from contextlib import asynccontextmanager
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
if config_id:
logger.info(f"使用配置 ID: {config_id}")
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
logger.info("Loading MCP tools: %s", [t.name for t in tools])
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_type_tool or not data_write_tool:
logger.error('不存在数据存储工具', exc_info=True)
raise ValueError('不存在数据存储工具')
# ToolNode
write_node = ToolNode([data_write_tool])
if not data_write_tool:
logger.error("Data_write tool not found", exc_info=True)
raise ValueError("Data_write tool not found")
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
result = await data_type_tool.ainvoke({
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
})
result=json.loads( result)
# 调用 Data_write传递 config_id
# Call Data_write directly with memory_config
write_params = {
"content": result["context"],
"content": content,
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id
"user_id": user_id,
"memory_config": memory_config,
}
# 如果提供了 config_id添加到参数中
if config_id:
write_params["config_id"] = config_id
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
content = write_result.get("data", str(write_result))
result_content = write_result.get("data", str(write_result))
else:
content = str(write_result)
logger.info("写入内容: %s", content)
return {"messages": [AIMessage(content=content)]}
result_content = str(write_result)
logger.info("Write content: %s", result_content)
return {"messages": [AIMessage(content=result_content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)

View File

@@ -10,19 +10,19 @@ Package structure:
- models: Pydantic response models
- services: Business logic services
"""
from app.core.memory.agent.mcp_server.server import (
mcp,
initialize_context,
main,
get_context_resource
)
# from app.core.memory.agent.mcp_server.server import (
# mcp,
# initialize_context,
# main,
# get_context_resource
# )
# Import tools to register them (but don't export them)
from app.core.memory.agent.mcp_server import tools
# # Import tools to register them (but don't export them)
# from app.core.memory.agent.mcp_server import tools
__all__ = [
'mcp',
'initialize_context',
'main',
'get_context_resource',
]
# __all__ = [
# 'mcp',
# 'initialize_context',
# 'main',
# 'get_context_resource',
# ]

View File

@@ -6,19 +6,15 @@ in the context for dependency injection into tool functions.
"""
import os
import sys
from mcp.server.fastmcp import FastMCP
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import store
logger = get_agent_logger(__name__)
@@ -78,17 +74,11 @@ def initialize_context():
logger.info("Registering session_store in context")
mcp.session_store = store
# Register LLM client
try:
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
llm_client = get_llm_client(SELECTED_LLM_ID)
mcp.llm_client = llm_client
logger.info("llm_client registered successfully")
except Exception as e:
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
# 注册一个 None 值,避免工具调用时找不到资源
mcp.llm_client = None
logger.warning("llm_client set to None due to initialization failure")
# Note: the LLM client is NOT loaded at server startup.
# Each tool builds its own client per request from the MemoryConfig that
# make_write_graph / make_read_graph pass through as `memory_config`.
logger.info("LLM client will be loaded dynamically with config_id when needed")
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
@@ -124,26 +114,20 @@ def main():
Initializes context and starts the server with SSE transport.
"""
try:
# logger.info("Starting MCP server initialization")
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
logger.info("Starting MCP server initialization")
# Initialize context resources
initialize_context()
# Import and register tools
# logger.info("Importing MCP tools")
from app.core.memory.agent.mcp_server.tools import (
# Import and register tools (imports trigger tool registration)
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
data_tools,
problem_tools,
retrieval_tools,
verification_tools,
summary_tools,
data_tools
verification_tools,
)
# logger.info("All MCP tools imported and registered")
# Log registered tools for debugging
import asyncio
tools_list = asyncio.run(mcp.list_tools())
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
# Tools are registered via imports above
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))

View File

@@ -4,22 +4,22 @@ Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
import json
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
@@ -28,8 +28,9 @@ class ParameterBuilder:
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
@@ -48,6 +49,7 @@ class ParameterBuilder:
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -58,18 +60,19 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
"group_id": group_id,
"memory_config": memory_config,
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
# Verify expects dict context
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
# These tools expect dict context
return {
"context": content if isinstance(content, dict) else {},
"context": content if isinstance(content, dict) else {"content": content},
**base_args
}

View File

@@ -4,21 +4,31 @@ Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
def __init__(self, memory_config: "MemoryConfig" = None):
"""
Initialize the search service.
Args:
memory_config: Optional MemoryConfig for embedding model configuration.
If not provided, must be passed to execute_hybrid_search.
"""
self.memory_config = memory_config
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
@@ -93,12 +103,13 @@ class SearchService:
self,
group_id: str,
question: str,
limit: int = 5,
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -112,6 +123,7 @@ class SearchService:
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
@@ -119,12 +131,17 @@ class SearchService:
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Use provided memory_config or fall back to instance config
config = memory_config or self.memory_config
if not config:
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
# Execute search using embedding_model_id from memory_config
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
@@ -132,7 +149,8 @@ class SearchService:
limit=limit,
include=include,
output_path=output_path,
rerank_alpha=rerank_alpha
embedding_id=str(config.embedding_model_id),
rerank_alpha=rerank_alpha,
)
# Extract results based on search type and include parameter

View File

@@ -3,16 +3,19 @@ Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from mcp.server.fastmcp import Context
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.retrieval_models import (
DistinguishTypeResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -20,7 +23,8 @@ logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str
context: str,
memory_config: MemoryConfig,
) -> dict:
"""
Distinguish the type of data (read or write).
@@ -28,6 +32,7 @@ async def Data_type_differentiation(
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
memory_config: MemoryConfig object containing LLM configuration
Returns:
dict: Contains 'context' with the original text and 'type' field
@@ -35,7 +40,9 @@ async def Data_type_differentiation(
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Render template
try:
@@ -53,7 +60,7 @@ async def Data_type_differentiation(
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
@@ -98,7 +105,7 @@ async def Data_write(
user_id: str,
apply_id: str,
group_id: str,
config_id: str
memory_config: MemoryConfig,
) -> dict:
"""
Write data to the database/file system.
@@ -109,7 +116,7 @@ async def Data_write(
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
config_id: Configuration ID for processing (optional, integer)
memory_config: MemoryConfig object containing all configuration
Returns:
dict: On success contains 'status', 'saved_to', 'data', 'config_id' and 'config_name'; on failure contains 'status' and 'message'
@@ -118,32 +125,28 @@ async def Data_write(
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data using utility function
try:
await write(content, user_id, apply_id, group_id, config_id=config_id)
logger.info(f"写入成功Config ID: {config_id if config_id else 'None'}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": config_id
}
except Exception as e:
logger.error(f"写入失败: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
}
except Exception as e:
logger.error(
f"Data_write failed: {e}",
exc_info=True
# Write data - clients are constructed inside write() from memory_config
await write(
content=content,
user_id=user_id,
apply_id=apply_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
"message": str(e),
}

View File

@@ -2,25 +2,23 @@
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import time
from typing import List
from pydantic import BaseModel, Field, RootModel
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse
ProblemExtensionResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -32,7 +30,8 @@ async def Split_The_Problem(
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
@@ -44,17 +43,20 @@ async def Split_The_Problem(
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
@@ -116,8 +118,8 @@ async def Split_The_Problem(
)
split_result = json.dumps([], ensure_ascii=False)
logger.info("问题拆分")
logger.info(f"问题拆分结果==>>:{split_result}")
logger.info("Problem splitting")
logger.info(f"Problem split result: {split_result}")
# Emit intermediate output for frontend
result = {
@@ -150,7 +152,7 @@ async def Split_The_Problem(
duration = end - start
except Exception:
duration = 0.0
log_time('问题拆分', duration)
log_time('Problem splitting', duration)
@mcp.tool()
@@ -160,8 +162,9 @@ async def Problem_Extension(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Extend the problem with additional sub-questions.
@@ -172,6 +175,7 @@ async def Problem_Extension(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -179,12 +183,14 @@ async def Problem_Extension(
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
@@ -250,8 +256,8 @@ async def Problem_Extension(
)
aggregated_dict = {}
logger.info("问题扩展")
logger.info(f"问题扩展==>>:{aggregated_dict}")
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
result = {
@@ -290,4 +296,4 @@ async def Problem_Extension(
duration = end - start
except Exception:
duration = 0.0
log_time('问题扩展', duration)
log_time('Problem extension', duration)

View File

@@ -3,25 +3,24 @@ Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
from dotenv import load_dotenv
import os
from app.core.rag.nlp.search import knowledge_retrieval
# 加载.env文件
load_dotenv()
import time
from typing import List
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
from app.core.memory.agent.utils.llm_tools import (
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@@ -32,8 +31,9 @@ async def Retrieve(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Retrieve data from the database using hybrid search.
@@ -44,6 +44,7 @@ async def Retrieve(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
@@ -66,6 +67,7 @@ async def Retrieve(
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
try:
# Extract services from context
@@ -77,7 +79,13 @@ async def Retrieve(
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
logger.info(f"Retrieve: context keys={list(context.keys())}")
content, original = await Retriev_messages_deal(context)
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
if not original:
logger.warning(f"Retrieve: original query is empty! context={context}")
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
@@ -113,9 +121,11 @@ async def Retrieve(
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
databases_anser.append({
"Query_small": cleaned_query,
@@ -206,9 +216,11 @@ async def Retrieve(
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
@@ -236,7 +248,7 @@ async def Retrieve(
}
logger.info(
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
@@ -279,4 +291,4 @@ async def Retrieve(
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
log_time('Retrieval', duration)

View File

@@ -2,33 +2,31 @@
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import os
import re
import time
from typing import List
from pydantic import BaseModel, Field
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.summary_models import (
SummaryData,
RetrieveSummaryResponse,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
Resolve_username
)
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
import os
from mcp.server.fastmcp import Context
# 加载.env文件
load_dotenv()
logger = get_agent_logger(__name__)
@@ -40,8 +38,9 @@ async def Summary(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize the verified data.
@@ -52,6 +51,7 @@ async def Summary(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -59,12 +59,14 @@ async def Summary(
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -155,7 +157,7 @@ async def Summary(
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"验证之后的总结==>>:{aimessages}")
logger.info(f"Summary after verification: {aimessages}")
# Log execution time
end = time.time()
@@ -163,7 +165,7 @@ async def Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('总结', duration)
log_time('Summary', duration)
return {
"status": "success",
@@ -180,8 +182,9 @@ async def Retrieve_Summary(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize data directly from retrieval results.
@@ -192,6 +195,7 @@ async def Retrieve_Summary(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -202,9 +206,11 @@ async def Retrieve_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -212,6 +218,8 @@ async def Retrieve_Summary(
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
@@ -252,17 +260,19 @@ async def Retrieve_Summary(
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
# Check for both Answer_Small and the legacy misspelling Answer_Samll for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
elif "Answer_Samll" in item:
answer = item["Answer_Samll"]
if answer is not None:
# Handle both string and list formats
@@ -350,7 +360,7 @@ async def Retrieve_Summary(
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"检索之后的总结==>>:{aimessages}")
logger.info(f"Summary after retrieval: {aimessages}")
# Log execution time
end = time.time()
@@ -358,7 +368,7 @@ async def Retrieve_Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('检索总结', duration)
log_time('Retrieval summary', duration)
# Emit intermediate output for frontend
return {
@@ -384,8 +394,9 @@ async def Input_Summary(
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Generate a quick summary for direct input without verification.
@@ -397,6 +408,7 @@ async def Input_Summary(
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
@@ -406,21 +418,14 @@ async def Input_Summary(
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Initialize variables to avoid UnboundLocalError
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
search_service = get_context_resource(ctx, 'search_service')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Check if llm_client is None
if llm_client is None:
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
logger.error(error_msg)
return error_msg
# Get LLM client from memory_config
llm_client = get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
@@ -479,7 +484,7 @@ async def Input_Summary(
# Add storage-specific parameters
'''检索'''
# Retrieval
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
@@ -509,12 +514,16 @@ async def Input_Summary(
except:
retrieve_info=''
raw_results=['']
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
logger.info("Input_Summary: 使用 summary 进行检索")
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
logger.info("Input_Summary: Using summary for retrieval")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
except Exception as e:
logger.error(
@@ -547,7 +556,7 @@ async def Input_Summary(
)
aimessages = "信息不足,无法回答"
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend
return {
@@ -587,7 +596,7 @@ async def Input_Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
log_time('Retrieval', duration)
@mcp.tool()

View File

@@ -5,20 +5,19 @@ This module contains MCP tools for verifying retrieved data.
"""
import time
from jinja2 import Template
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.core.memory.agent.utils.messages_tool import (
Verify_messages_deal,
Retrieve_verify_tool_messages_deal,
Resolve_username
)
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Retrieve_verify_tool_messages_deal,
Verify_messages_deal,
)
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.schemas.memory_config_schema import MemoryConfig
from jinja2 import Template
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -30,6 +29,7 @@ async def Verify(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
@@ -42,6 +42,7 @@ async def Verify(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -91,8 +92,12 @@ async def Verify(
# Call verification workflow
verify_tool = VerifyTool(system_prompt, messages)
# Call verification workflow with LLM model ID from memory_config
verify_tool = VerifyTool(
system_prompt=system_prompt,
verify_data=messages,
llm_model_id=str(memory_config.llm_model_id)
)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
@@ -118,7 +123,7 @@ async def Verify(
"history": history,
}
logger.info(f"验证==>>:{messages_deal}")
logger.info(f"Verification result: {messages_deal}")
# Emit intermediate output for frontend
return {
@@ -128,7 +133,7 @@ async def Verify(
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "数据验证",
"title": "Data Verification",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
@@ -166,4 +171,4 @@ async def Verify(
duration = end - start
except Exception:
duration = 0.0
log_time('验证', duration)
log_time('Verification', duration)

View File

@@ -1,22 +1,21 @@
import asyncio
import json
from collections import defaultdict
from typing import TypedDict, Annotated
import os
import logging
from jinja2 import Template
from langchain_core.messages import AnyMessage
from dotenv import load_dotenv
from langgraph.graph import add_messages
from openai import OpenAI
import os
from collections import defaultdict
from typing import Annotated, TypedDict
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
from app.core.models.base import RedBearModelConfig
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import (
get_picture_config,
get_voice_config,
)
# Removed global variable imports - use dependency injection instead
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from openai import OpenAI
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
@@ -44,6 +43,7 @@ class WriteState(TypedDict):
user_id:str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
class ReadState(TypedDict):
'''
@@ -53,6 +53,7 @@ class ReadState(TypedDict):
loop_count:Traverse times
search_switch: type
config_id: configuration id for filtering results
errors: list of errors that occurred during workflow execution
'''
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
name: str
@@ -63,6 +64,7 @@ class ReadState(TypedDict):
apply_id: str
group_id: str
config_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
class COUNTState:
@@ -109,9 +111,17 @@ def deduplicate_entries(entries):
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
image_path: Path to image file
PROMPT_TICKET_EXTRACTION: Extraction prompt
picture_model_name: Picture model name (required, no longer from global variables)
"""
try:
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
model_config = get_picture_config(picture_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
@@ -147,9 +157,15 @@ async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize():
async def Voice_recognize(voice_model_name: str):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
voice_model_name: Voice model name (required, no longer from global variables)
"""
try:
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
model_config = get_voice_config(voice_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)

View File

@@ -1,10 +1,10 @@
import json
import logging
import re
from typing import List, Any
from typing import Any, List
from langchain_core.messages import AnyMessage
from app.core.logging_config import get_agent_logger
from langchain_core.messages import AnyMessage
logger = get_agent_logger(__name__)
@@ -119,11 +119,23 @@ async def Problem_Extension_messages_deal(context):
extent_quest = []
original = context.get('original', '')
messages = context.get('context', '')
messages = json.loads(messages)
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
# Handle empty or non-string messages
if not messages:
return extent_quest, original
if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError:
# If JSON parsing fails, return the (still empty) question list together with the original text
return extent_quest, original
if isinstance(messages, list):
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
return extent_quest, original
@@ -135,10 +147,19 @@ async def Retriev_messages_deal(context):
context:
Returns:
'''
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
if isinstance(context, dict):
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
if 'context' in context or 'original' in context:
return context.get('context', {}), context.get('original', '')
return content, original_value
content = context.get('context', {})
original = context.get('original', '')
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
return content, original
# Return empty defaults if context is not a dict or doesn't have expected keys
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
return {}, ''
async def Verify_messages_deal(context):
'''

View File

@@ -1,22 +1,22 @@
# 角色
你是验证专家
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析是不是回答Query_Samll这个字段的问题
你的目标是针对用户的输入Query_Small字段的提问和Answer_Small的回答分析是不是回答Query_Small这个字段的问题
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
## 工作步骤
1. 获取所有的Query_Samll字段和Answer_Samll字段
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
1. 获取所有的Query_Small字段和Answer_Small字段
2. 分析Answer_Small的回复是不是和Query_Small有关系
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Samll和Answer_Samll
获取用户的输入提取对应的Query_Small和Answer_Small
### 第二步 分析验证
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容如果有关系不是答非所问
需要分析Query_Small和Answer_Small之间的关系可以参考history字段的内容如果有关系不是答非所问
## 核心验证标准
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Samll
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Small
1. 合理性标准(必须全部满足)
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
- 最小化每个不同的子问题数量应尽可能少通常不超过原问题关键要素数量的2倍建议2-4个避免冗余和不必要拆分。

View File

@@ -19,12 +19,14 @@ class DistinguishTypeResponse(BaseModel):
type: str
async def status_typle(messages: str) -> dict:
async def status_typle(messages: str, llm_model_id: str) -> dict:
"""
Classify message type as read or write operation.
Updated to eliminate global variables in favor of explicit parameters.
Args:
messages: User message to classify
llm_model_id: LLM model ID to use (required, no longer from global variables)
Returns:
dict: Contains 'type' field with classification result
@@ -42,8 +44,7 @@ async def status_typle(messages: str) -> dict:
"message": f"Prompt rendering failed: {str(e)}"
}
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
llm_client = get_llm_client(llm_model_id)
try:
structured = await llm_client.response_structured(

View File

@@ -11,7 +11,7 @@ from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
# Removed global variable imports - use dependency injection instead
from app.core.logging_config import get_agent_logger
load_dotenv(find_dotenv())
@@ -31,8 +31,17 @@ class State(TypedDict):
class VerifyTool:
def __init__(self, system_prompt: str="", verify_data: Any=None):
def __init__(self, system_prompt: str="", verify_data: Any=None, llm_model_id: str=None):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
system_prompt: System prompt for verification
verify_data: Data to verify
llm_model_id: LLM model ID. Defaults to None in the signature, but model_1 raises ValueError when it is not set (no longer read from global variables)
"""
self.system_prompt = system_prompt
self.llm_model_id = llm_model_id
if isinstance(verify_data, str):
self.verify_data = verify_data
else:
@@ -42,7 +51,9 @@ class VerifyTool:
self.verify_data = str(verify_data)
async def model_1(self, state: State) -> State:
llm_client = get_llm_client(SELECTED_LLM_ID)
if not self.llm_model_id:
raise ValueError("llm_model_id is required but not provided")
llm_client = get_llm_client(self.llm_model_id)
response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
)

View File

@@ -1,80 +1,93 @@
import asyncio
from dotenv import load_dotenv
"""
Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import time
from datetime import datetime
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation_all,
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
)
from app.core.memory.utils.embedder.embedder_utils import (
get_embedder_client_from_config,
)
from app.core.memory.utils.llm.llm_utils import get_llm_client_from_config
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
load_dotenv()
logger = get_agent_logger(__name__)
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
async def write(
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
ref_id: str = "wyl20251027",
) -> None:
"""
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator
Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args:
content: 对话内容
user_id: 用户ID
apply_id: 应用ID
group_id: 组ID
ref_id: 参考ID默认为 "wyl20251027"
config_id: 配置ID用于标记数据处理配置
content: Dialogue content to process
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
ref_id: Reference ID, defaults to "wyl20251027"
"""
# Extract config values
embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy
config_id = str(memory_config.config_id)
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
logger.info(f"Config ID: {config_id if config_id else 'None'}")
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
logger.info(f"Workspace: {memory_config.workspace_name}")
logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
# Construct clients from memory_config
llm_client = get_llm_client_from_config(memory_config)
embedder_client = get_embedder_client_from_config(memory_config)
logger.info("LLM and embedding clients constructed")
# Initialize timing log
log_file = "logs/time.log"
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
pipeline_start = time.time()
# 初始化客户端
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
# Initialize Neo4j connector
neo4j_connector = Neo4jConnector()
# Step 1: 加载和分块数据
# Step 1: Load and chunk data
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
@@ -83,21 +96,21 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
config_id=config_id,
)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# Step 2: 初始化并运行 ExtractionOrchestrator
# Step 2: Initialize and run ExtractionOrchestrator
step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
pipeline_config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
config=pipeline_config,
embedding_id=embedding_model_id,
)
# 运行完整的提取流水线
# orchestrator.run returns a flat tuple of 7 values after deduplication
# Run the complete extraction pipeline
(
all_dialogue_nodes,
all_chunk_nodes,
@@ -107,14 +120,12 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
all_statement_entity_edges,
all_entity_entity_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# Step 8: Save all data to Neo4j database using graph models
# Step 3: Save all data to Neo4j database
step_start = time.time()
# 运行索引创建
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
@@ -141,18 +152,16 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
# Step 4: Generate Memory summaries and save to Neo4j
step_start = time.time()
try:
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
)
# Save memory summaries to Neo4j as nodes
try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector)
# Link summaries to statements via chunks for summary→entity queries
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
try:
@@ -162,24 +171,15 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
finally:
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
# Log total pipeline time
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# Add completion marker to log
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
if __name__ == "__main__":
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
asyncio.run(write(content, ref_id="wyl20251027"))

View File

@@ -1,8 +1,9 @@
import sys
import os
import asyncio
from neo4j import GraphDatabase
import os
import sys
from typing import List, Tuple
from neo4j import GraphDatabase
from pydantic import BaseModel, Field
# ------------------- 自包含路径解析 -------------------
@@ -31,11 +32,16 @@ except NameError:
# ---------------------------------------------------------------------
# 现在路径已经配置好,我们可以使用绝对导入
from app.core.config import settings
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
import json
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import get_llm_client
# TODO: Replace these environment-variable fallbacks with explicitly passed-in LLM and group IDs
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
@@ -140,8 +146,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认False按group_id查询
"""
# 默认从 runtime.json selections.group_id 读取
group_id = group_id or SELECTED_GROUP_ID
# 默认从环境变量读取
group_id = group_id or DEFAULT_GROUP_ID
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
if not raw_tags_with_freq:
@@ -150,8 +156,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
llm_client = get_llm_client(DEFAULT_LLM_ID)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
@@ -165,8 +170,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
if __name__ == "__main__":
print("开始获取热门记忆标签...")
try:
# 直接使用 runtime.json 中的 group_id
group_id_to_query = SELECTED_GROUP_ID
# 直接使用环境变量中的 group_id
group_id_to_query = DEFAULT_GROUP_ID
# 使用 asyncio.run 来执行异步主函数
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))

View File

@@ -5,9 +5,9 @@ This script can be executed directly to generate a memory insight report for a t
"""
import asyncio
import json
import os
import sys
import json
from collections import Counter
from datetime import datetime
@@ -17,12 +17,16 @@ src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if src_path not in sys.path:
sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from pydantic import BaseModel, Field
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
# TODO: Replace these environment-variable fallbacks with explicitly passed-in LLM and group IDs
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class TagClassification(BaseModel):
@@ -55,8 +59,7 @@ class MemoryInsight:
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
self.llm_client = get_llm_client(DEFAULT_LLM_ID)
async def close(self):
"""关闭数据库连接。"""
@@ -294,8 +297,8 @@ async def main():
"""
Initializes and runs the memory insight analysis for a test user.
"""
# 默认从 runtime.json selections.group_id 读取
test_user_id = SELECTED_GROUP_ID
# 默认从环境变量读取
test_user_id = DEFAULT_GROUP_ID
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
insight = None

View File

@@ -6,10 +6,10 @@ Usage:
python -m analytics.user_summary --user_id <group_id>
"""
import os
import sys
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple
@@ -24,10 +24,15 @@ try:
except Exception:
pass
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# TODO: Replace these environment-variable fallbacks with explicitly passed-in LLM and group IDs
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
@dataclass
@@ -42,8 +47,7 @@ class UserSummary:
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
from app.core.memory.utils.config import definitions as config_defs
self.llm = get_llm_client(config_defs.SELECTED_LLM_ID)
self.llm = get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()
@@ -107,8 +111,8 @@ class UserSummary:
async def generate_user_summary(user_id: str | None = None) -> str:
# 默认从 runtime.json selections.group_id 读取
effective_group_id = user_id or SELECTED_GROUP_ID
# 默认从环境变量读取
effective_group_id = user_id or DEFAULT_GROUP_ID
svc = UserSummary(effective_group_id)
try:
return await svc.generate()
@@ -139,7 +143,7 @@ if __name__ == "__main__":
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["user_summary"] = {
"group_id": SELECTED_GROUP_ID,
"group_id": DEFAULT_GROUP_ID,
"summary": summary
}
with open(dashboard_path, "w", encoding="utf-8") as wf:

View File

@@ -1,132 +0,0 @@
{
"llm_list": [
{
"llm_name": "qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_AGENT_KEY"
},
{
"llm_name": "openai/qwen2.5-14b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen3-14b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/deepseek-r1-0528-qwen3-8b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen3-235b-a22b-instruct-2507",
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "DASHSCOPE_API_KEY"
}
,
{
"llm_name": "openai/qwen-plus",
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "DASHSCOPE_API_KEY"
},
{
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0"
},
{
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0"
}
],
"embedding_list": [
{
"embedding_name": "openai/nomic-embed-text:v1.5",
"api_base": "http://119.45.239.97:11434/v1",
"dimension": 768
},
{
"embedding_name": "openai/bge-m3",
"api_base": "http://43.137.4.24:9090/v1",
"dimension": 1024
}
],
"neo4j": {
"uri": "bolt://1.94.111.67:7687",
"username": "neo4j"
},
"chunker_list": [
{
"chunker_strategy": "TokenChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"chunk_overlap": 56,
"tokenizer_or_token_counter": "character"
},
{
"chunker_strategy": "RecursiveChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"min_characters_per_chunk": 50
},
{
"chunker_strategy": "SemanticChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 1024,
"threshold": 0.8,
"min_sentences": 2,
"skip_window": 1,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "LateChunker",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size": 2048,
"min_characters_per_chunk": 24
},
{
"chunker_strategy": "NeuralChunker",
"embedding_model": "mirth/chonky_modernbert_base_1",
"min_characters_per_chunk": 24
},
{
"chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 1000,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "HybridChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"threshold": 0.8,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "SentenceChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 2048,
"chunk_overlap": 128,
"min_sentences_per_chunk": 1,
"min_characters_per_sentence": 12,
"delim": [".", "!", "?", "\n"],
"include_delim": "prev",
"tokenizer_or_token_counter": "character"
}
],
"langfuse": {
"enabled": true
},
"agenta": {
"enabled": false
}
}

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +0,0 @@
{
"selections": {
"config_id": ""
}
}

View File

@@ -51,16 +51,31 @@ logger = get_memory_logger(__name__)
async def main(
# Required configuration parameters (no longer from global variables)
chunker_strategy: str,
group_id: str,
user_id: str,
apply_id: str,
llm_model_id: str,
embedding_model_id: str,
# Optional parameters
dialogue_text: Optional[str] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
"""
记忆系统主流程 - 重构版本
记忆系统主流程 - 重构版本 (Updated to eliminate global variables)
该函数是重构后的主入口,使用新的模块化架构。
Global variables have been eliminated in favor of explicit parameters.
Args:
chunker_strategy: Chunking strategy to use (required)
group_id: Group ID for the operation (required)
user_id: User ID for the operation (required)
apply_id: Application ID for the operation (required)
llm_model_id: LLM model ID to use (required)
embedding_model_id: Embedding model ID to use (required)
dialogue_text: 输入的对话文本(可选,用于试运行模式)
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
@@ -82,12 +97,10 @@ async def main(
print("MemSci 知识提取流水线 - 重构版本")
print("=" * 60)
print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
print("Using group ID:", config_defs.SELECTED_GROUP_ID)
print("Using model ID:", config_defs.SELECTED_LLM_ID)
print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
print("Using chunker strategy:", chunker_strategy)
print("Using group ID:", group_id)
print("Using model ID:", llm_model_id)
print("Using embedding model ID:", embedding_model_id)
print("=" * 60)
# 初始化日志
@@ -104,11 +117,11 @@ async def main(
logger.info("Initializing clients...")
step_start = time.time()
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
llm_client = get_llm_client(llm_model_id)
# 获取 embedder 配置并转换为 RedBearModelConfig 对象
from app.core.models.base import RedBearModelConfig
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config_dict = get_embedder_config(embedding_model_id)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
@@ -145,9 +158,9 @@ async def main(
dialog = DialogData(
context=context,
ref_id="pilot_dialog_1",
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
@@ -158,7 +171,7 @@ async def main(
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
chunker_strategy=chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
@@ -172,7 +185,7 @@ async def main(
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
@@ -180,7 +193,7 @@ async def main(
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
else:
@@ -199,11 +212,11 @@ async def main(
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=config_defs.SELECTED_GROUP_ID,
user_id=config_defs.SELECTED_USER_ID,
apply_id=config_defs.SELECTED_APPLY_ID,
indices=config_defs.SELECTED_TEST_DATA_INDICES,
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
indices=None,
input_data_path=test_data_path,
llm_client=llm_client,
skip_cleaning=True,
@@ -219,7 +232,7 @@ async def main(
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
@@ -227,7 +240,7 @@ async def main(
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
"chunker_strategy": chunker_strategy
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
@@ -249,6 +262,7 @@ async def main(
connector=neo4j_connector,
config=config,
progress_callback=progress_callback, # 传递进度回调
embedding_id=embedding_model_id, # 传递嵌入模型ID
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
@@ -352,7 +366,7 @@ async def main(
)
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
chunked_dialogs, llm_client=llm_client, embedding_id=embedding_model_id
)
if not is_pilot_run:
@@ -400,4 +414,17 @@ async def main(
if __name__ == "__main__":
asyncio.run(main())
print("⚠️ Warning: This script now requires explicit configuration parameters.")
print("Global variables have been removed. Please provide configuration parameters.")
print("Example usage:")
print(" asyncio.run(main(")
print(" chunker_strategy='RecursiveChunker',")
print(" group_id='your_group_id',")
print(" user_id='your_user_id',")
print(" apply_id='your_apply_id',")
print(" llm_model_id='your_llm_id',")
print(" embedding_model_id='your_embedding_id'")
print(" ))")
# Fail fast: main() now requires explicit configuration arguments, so running this module directly is unsupported
raise RuntimeError("Global variables removed. Please provide explicit configuration parameters.")

View File

@@ -1,31 +1,41 @@
import argparse
import asyncio
import json
import math
import os
import time
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
from datetime import datetime
import math
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_memory_logger
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import (
search_graph_by_embedding, search_graph,
search_graph_by_temporal, search_graph_by_keyword_temporal,
search_graph_by_chunk_id
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.config_models import TemporalSearchParams
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
ForgettingEngine,
)
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import (
get_embedder_config,
get_pipeline_config,
)
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.models.base import RedBearModelConfig
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_chunk_id,
search_graph_by_embedding,
search_graph_by_keyword_temporal,
search_graph_by_temporal,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from dotenv import load_dotenv
load_dotenv()
logger = get_memory_logger(__name__)
@@ -131,7 +141,7 @@ def rerank_hybrid_results(
# Add keyword results with BM25 scores
for item in keyword_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
@@ -139,7 +149,7 @@ def rerank_hybrid_results(
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
if item_id in combined_items:
# Update existing item with embedding score
@@ -220,7 +230,7 @@ def rerank_with_forgetting_curve(
(keyword_items, False), (embedding_items, True)
):
for item in src_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
existing = combined_items.get(item_id)
@@ -266,26 +276,25 @@ def rerank_with_forgetting_curve(
return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = "search_log.txt"):
"""Log search query information to file"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger.
Args:
query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering
limit: Maximum number of results
include: List of result types to include
log_file: Deprecated and ignored (output now goes to the module logger); kept only for backward compatibility
"""
# Ensure the query text is plain and clean before logging
cleaned_query = extract_plain_query(query_text)
log_entry = {
"timestamp": timestamp,
# "query": query_text,
"query": cleaned_query,
"search_type": search_type,
"group_id": group_id,
"limit": limit,
"include": include
}
# Append to log file
with open(log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
logger.info(f"Search logged: {query_text} ({search_type})")
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}"
)
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
@@ -547,6 +556,7 @@ async def run_hybrid_search(
limit: int,
include: List[str],
output_path: str | None,
embedding_id: str,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -558,6 +568,7 @@ async def run_hybrid_search(
# Start overall timing
search_start_time = time.time()
latency_metrics = {}
logger.info(f"using embedding_id:{embedding_id}...")
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
@@ -610,7 +621,7 @@ async def run_hybrid_search(
# Load the embedder configuration from the database (by ID) and build a RedBearModelConfig
config_load_start = time.time()
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config_dict = get_embedder_config(embedding_id)
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
@@ -759,18 +770,11 @@ async def run_hybrid_search(
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
completion_log = {
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"query": query_text,
"search_type": search_type,
"status": "completed",
"result_counts": result_counts,
"output_file": output_path,
"latency_metrics": latency_metrics
}
with open("search_log.txt", "a", encoding="utf-8") as f:
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
# Log completion using the standard logger
logger.info(
f"Search completed: query='{query_text}', type={search_type}, "
f"result_counts={result_counts}, latency={latency_metrics}"
)
return results
@@ -969,6 +973,7 @@ def main():
limit=args.limit,
include=args.include,
output_path=args.output,
embedding_id=config_defs.SELECTED_EMBEDDING_ID,
rerank_alpha=args.rerank_alpha,
use_forgetting_rerank=args.forgetting_rerank,
use_llm_rerank=args.llm_rerank,

View File

@@ -19,50 +19,50 @@
import asyncio
import logging
import os
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.message_models import DialogData
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
StatementNode,
)
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
StatementExtractionConfig,
)
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
TripletExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
TemporalExtractor,
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
embedding_generation_all,
generate_entity_embeddings_from_triplets,
)
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
# 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
TemporalExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
TripletExtractor,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.data.ontology import TemporalInfo
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 配置日志
logger = logging.getLogger(__name__)
@@ -96,6 +96,7 @@ class ExtractionOrchestrator:
connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None,
):
"""
初始化流水线编排器
@@ -108,6 +109,7 @@ class ExtractionOrchestrator:
progress_callback: 进度回调函数
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
- 在管线关键点调用以报告进度和结果数据
embedding_id: Embedding model ID. There is no fallback to global configuration: when None, the embedding generation steps cannot run.
"""
self.llm_client = llm_client
self.embedder_client = embedder_client
@@ -115,6 +117,7 @@ class ExtractionOrchestrator:
self.config = config or ExtractionPipelineConfig()
self.is_pilot_run = False # 默认非试运行模式
self.progress_callback = progress_callback # 保存进度回调函数
self.embedding_id = embedding_id # 保存嵌入模型ID
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
@@ -420,7 +423,9 @@ class ExtractionOrchestrator:
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
except Exception as e:
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
return TripletExtractionResponse(triplets=[], entities=[])
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
@@ -434,7 +439,9 @@ class ExtractionOrchestrator:
d_idx, stmt_id = statement_metadata[i]
if isinstance(result, Exception):
logger.error(f"陈述句处理异常: {result}")
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
triplet_maps[d_idx][stmt_id] = TripletExtractionResponse(triplets=[], entities=[])
else:
triplet_maps[d_idx][stmt_id] = result
@@ -521,8 +528,8 @@ class ExtractionOrchestrator:
temporal_maps[d_idx][stmt_id] = result
# 为 ATEMPORAL 陈述句添加空的时间范围
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.models.message_models import TemporalValidityRange
from app.core.memory.utils.data.ontology import TemporalInfo
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
@@ -629,17 +636,14 @@ class ExtractionOrchestrator:
logger.info("开始生成基础嵌入向量(陈述句、分块、对话)")
try:
# 从 runtime.json 获取嵌入模型配置ID
from app.core.memory.utils.config import definitions as config_defs
embedding_id = config_defs.SELECTED_EMBEDDING_ID
if not embedding_id:
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
raise ValueError("未配置嵌入模型ID")
# embedding_id is required - no fallback to global variable
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
raise ValueError("embedding_id is required but was not provided")
# 只生成陈述句、分块和对话的嵌入(不包括实体)
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation(
dialog_data_list, embedding_id
dialog_data_list, self.embedding_id
)
# 统计生成结果
@@ -683,17 +687,14 @@ class ExtractionOrchestrator:
logger.info("开始生成实体嵌入向量")
try:
# 从 runtime.json 获取嵌入模型配置ID
from app.core.memory.utils.config import definitions as config_defs
embedding_id = config_defs.SELECTED_EMBEDDING_ID
if not embedding_id:
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
# No fallback to a global variable: without embedding_id, log the error and return triplet_maps unchanged (entity embeddings are skipped)
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
return triplet_maps
# 生成实体嵌入
updated_triplet_maps = await generate_entity_embeddings_from_triplets(
triplet_maps, embedding_id
triplet_maps, self.embedding_id
)
logger.info("实体嵌入生成完成")
@@ -1086,7 +1087,9 @@ class ExtractionOrchestrator:
if self.is_pilot_run:
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
# 只执行第一层去重
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
@@ -1608,8 +1611,8 @@ async def get_chunked_dialogs(
包含分块的 DialogData 对象列表
"""
import json
import re
import os
import re
# 加载测试数据
testdata_path = os.path.join(os.path.dirname(__file__), "../../data", "testdata.json")
@@ -1671,7 +1674,9 @@ async def get_chunked_dialogs(
)
# 创建分块器并处理对话
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
@@ -1718,7 +1723,9 @@ def preprocess_data(
经过清洗转换后的 DialogData 列表
"""
print("\n=== 数据预处理 ===")
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
preprocessor = DataPreprocessor()
try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
@@ -1749,7 +1756,9 @@ async def get_chunked_dialogs_from_preprocessed(
raise ValueError("预处理数据为空,无法进行分块")
all_chunked_dialogs: List[DialogData] = []
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
for dialog_data in data:
chunker = DialogueChunker(chunker_strategy, llm_client=llm_client)
@@ -1811,7 +1820,9 @@ async def get_chunked_dialogs_with_preprocessing(
# 步骤2: 语义剪枝
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
pruner = SemanticPruner(llm_client=llm_client)
# 记录单对话场景下剪枝前的消息数量
@@ -1834,7 +1845,9 @@ async def get_chunked_dialogs_with_preprocessing(
# 保存剪枝后的数据
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
pruned_output_path = settings.get_memory_output_path("pruned_data.json")
dp = DataPreprocessor(output_file_path=pruned_output_path)
dp.save_data(preprocessed_data, output_path=pruned_output_path)

View File

@@ -1,447 +0,0 @@
# TODO hybrid_chatbot.py 是一个独立的GUI演示应用不是核心功能的一部分可以考虑删除
from app.core.memory.utils.llm.llm_utils import get_llm_client
import asyncio
import os
import time
import json
from datetime import datetime, timezone
import tkinter as tk
from tkinter import scrolledtext, messagebox
import threading
from typing import Any, Dict, Tuple, List
# Import our hybrid search functionality
from app.core.memory.storage_services.search import run_hybrid_search
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.config_models import LLMConfig
from dotenv import load_dotenv
load_dotenv()
class HybridSearchChatbot:
def __init__(self):
from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# Chat history
self.chat_history = []
# Search configuration
self.search_config = {
"group_id": "group_wyl_25",
"limit": 10,
"include": ["statements", "chunks", "entities","summaries"],
# "include": ["statements", "dialogues", "entities"],
"rerank_alpha": 0.6
}
# Setup GUI
self.setup_gui()
def setup_gui(self):
"""Setup the GUI interface"""
self.root = tk.Tk()
self.root.title("Hybrid Search Chatbot")
self.root.geometry("800x600")
# Chat display area
self.chat_display = scrolledtext.ScrolledText(
self.root,
wrap=tk.WORD,
width=80,
height=25,
state=tk.DISABLED
)
self.chat_display.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)
# Input frame
input_frame = tk.Frame(self.root)
input_frame.pack(padx=10, pady=5, fill=tk.X)
# User input
self.user_input = tk.Entry(input_frame, font=("Arial", 12))
self.user_input.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
self.user_input.bind("<Return>", self.on_send_message)
# Send button
self.send_button = tk.Button(
input_frame,
text="发送",
command=self.on_send_message,
font=("Arial", 12)
)
self.send_button.pack(side=tk.RIGHT)
# Status frame
status_frame = tk.Frame(self.root)
status_frame.pack(padx=10, pady=5, fill=tk.X)
# Status label
self.status_label = tk.Label(
status_frame,
text="就绪",
font=("Arial", 10),
anchor="w"
)
self.status_label.pack(side=tk.LEFT, fill=tk.X, expand=True)
# Search config button
config_button = tk.Button(
status_frame,
text="搜索配置",
command=self.show_config_dialog,
font=("Arial", 10)
)
config_button.pack(side=tk.RIGHT)
# Add welcome message
self.add_message("系统", "欢迎使用混合搜索聊天机器人!我可以基于知识图谱中的信息回答您的问题。")
def add_message(self, sender: str, message: str, metadata: Dict = None):
"""Add a message to the chat display"""
self.chat_display.config(state=tk.NORMAL)
timestamp = datetime.now().strftime("%H:%M:%S")
# Add sender and timestamp
self.chat_display.insert(tk.END, f"[{timestamp}] {sender}:\n", "sender")
# Add message content
self.chat_display.insert(tk.END, f"{message}\n", "message")
# Add metadata if available
if metadata:
self.chat_display.insert(tk.END, f" {metadata}\n", "metadata")
self.chat_display.insert(tk.END, "\n")
self.chat_display.config(state=tk.DISABLED)
self.chat_display.see(tk.END)
# Configure text tags for styling
self.chat_display.tag_config("sender", foreground="blue", font=("Arial", 10, "bold"))
self.chat_display.tag_config("message", foreground="black", font=("Arial", 10))
self.chat_display.tag_config("metadata", foreground="gray", font=("Arial", 8))
def show_config_dialog(self):
"""Show search configuration dialog"""
config_window = tk.Toplevel(self.root)
config_window.title("搜索配置")
config_window.geometry("400x600")
config_window.transient(self.root)
config_window.grab_set()
# Current configuration display
current_config_frame = tk.Frame(config_window)
current_config_frame.pack(pady=10, padx=10, fill=tk.X)
tk.Label(current_config_frame, text="当前配置:", font=("Arial", 10, "bold")).pack(anchor="w")
current_text = f"Alpha: {self.search_config['rerank_alpha']}, 限制: {self.search_config['limit']}, 目标: {', '.join(self.search_config['include'])}"
tk.Label(current_config_frame, text=current_text, font=("Arial", 9), fg="blue").pack(anchor="w")
# Alpha parameter
tk.Label(config_window, text="重排权重 (Alpha):").pack(pady=(10, 5))
alpha_var = tk.DoubleVar(value=self.search_config["rerank_alpha"])
alpha_scale = tk.Scale(
config_window,
from_=0.0,
to=1.0,
resolution=0.1,
orient=tk.HORIZONTAL,
variable=alpha_var
)
alpha_scale.pack(pady=5, padx=20, fill=tk.X)
tk.Label(config_window, text="0.0=纯语义搜索, 1.0=纯关键词搜索", font=("Arial", 8)).pack()
# Limit parameter
tk.Label(config_window, text="搜索结果数量:").pack(pady=(20, 5))
limit_var = tk.IntVar(value=self.search_config["limit"])
limit_spinbox = tk.Spinbox(
config_window,
from_=1,
to=50,
textvariable=limit_var,
width=10
)
limit_spinbox.pack(pady=5)
# Include options
tk.Label(config_window, text="搜索目标:").pack(pady=(20, 5))
include_frame = tk.Frame(config_window)
include_frame.pack(pady=5)
include_vars = {}
for option in ["statements", "chunks", "entities","summaries"]:
var = tk.BooleanVar(value=option in self.search_config["include"])
include_vars[option] = var
tk.Checkbutton(
include_frame,
text=option,
variable=var
).pack(side=tk.LEFT, padx=10)
# Buttons
button_frame = tk.Frame(config_window)
button_frame.pack(pady=20)
def save_config():
try:
# Validate inputs
alpha_value = alpha_var.get()
limit_value = limit_var.get()
include_list = [
option for option, var in include_vars.items() if var.get()
]
# Check if at least one search target is selected
if not include_list:
messagebox.showerror("配置错误", "请至少选择一个搜索目标!")
return
# Update configuration
self.search_config["rerank_alpha"] = alpha_value
self.search_config["limit"] = limit_value
self.search_config["include"] = include_list
config_window.destroy()
self.add_message("系统",
f"配置已更新: Alpha={alpha_value:.1f}, 限制={limit_value}, 目标={', '.join(include_list)}")
except Exception as e:
messagebox.showerror("配置错误", f"保存配置时出错: {str(e)}")
print(f"Config save error: {e}") # Debug output
tk.Button(button_frame, text="保存", command=save_config).pack(side=tk.LEFT, padx=5)
tk.Button(button_frame, text="取消", command=config_window.destroy).pack(side=tk.LEFT, padx=5)
def on_send_message(self, event=None):
    """Send the text currently in the input box.

    Blank or whitespace-only input is ignored. Otherwise the input box is
    cleared, the message is echoed into the chat display, the send button
    is disabled, and the message is handed to ``process_message_async`` on
    a daemon thread.

    Args:
        event: Optional event object supplied by a key binding; unused.
    """
    text = self.user_input.get().strip()
    if text:
        # Empty the entry widget before echoing the message.
        self.user_input.delete(0, tk.END)
        self.add_message("用户", text)

        # Disable sending and show a busy status while the worker runs.
        self.send_button.config(state=tk.DISABLED)
        self.status_label.config(text="正在搜索和生成回复...")

        worker = threading.Thread(
            target=self.process_message_async,
            args=(text,),
            daemon=True,
        )
        worker.start()
def process_message_async(self, user_message: str):
    """Run ``process_message`` to completion on a private event loop.

    A fresh asyncio event loop is created for the call. The outcome is
    handed back through ``self.root.after``: ``on_response_ready`` with
    the response and metadata on success, ``on_error`` with a formatted
    message on any failure.

    Args:
        user_message: The text the user submitted.
    """
    try:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            response, metadata = loop.run_until_complete(
                self.process_message(user_message)
            )
        finally:
            # Always release the loop, including when the coroutine raised;
            # otherwise each failed request leaks an open event loop.
            loop.close()
        # Hand the result back via the Tk root's scheduler.
        self.root.after(0, self.on_response_ready, response, metadata)
    except Exception as e:
        error_msg = f"处理消息时出错: {str(e)}"
        self.root.after(0, self.on_error, error_msg)
async def process_message(self, user_message: str) -> Tuple[str, Dict[str, Any]]:
    """Answer one user message: hybrid search first, then LLM generation.

    Args:
        user_message: The text the user submitted.

    Returns:
        A ``(response, metadata)`` pair. ``metadata`` holds the formatted
        search, generation and total durations, a one-line summary of the
        search results, and the rerank alpha in effect.
    """
    overall_t0 = time.time()

    # Retrieval step, driven by the current search configuration.
    retrieval_t0 = time.time()
    search_kwargs = {
        "query_text": user_message,
        "search_type": "hybrid",
        "group_id": self.search_config["group_id"],
        "limit": self.search_config["limit"],
        "include": self.search_config["include"],
        "output_path": None,
        "rerank_alpha": self.search_config["rerank_alpha"],
    }
    hits = await run_hybrid_search(**search_kwargs)
    retrieval_elapsed = time.time() - retrieval_t0

    # Flatten the hits into a text context for the prompt.
    prompt_context = self.extract_context_from_search(hits)

    # Generation step.
    generation_t0 = time.time()
    reply = await self.generate_response(user_message, prompt_context)
    generation_elapsed = time.time() - generation_t0

    overall_elapsed = time.time() - overall_t0

    details = {
        "搜索时间": f"{retrieval_elapsed:.2f}s",
        "生成时间": f"{generation_elapsed:.2f}s",
        "总时间": f"{overall_elapsed:.2f}s",
        "搜索结果": self.get_search_summary(hits),
        "重排权重": self.search_config["rerank_alpha"],
    }
    return reply, details
def extract_context_from_search(self, search_results: Dict) -> str:
    """Render search results as a plain-text context block for the LLM prompt.

    Uses ``search_results["reranked_results"]`` when that key is present;
    otherwise merges the per-category lists found under ``"keyword_search"``
    and ``"embedding_search"``. Emits up to 5 statements, 3 chunks and
    5 entity names.

    Args:
        search_results: Mapping returned by the hybrid search; may be empty
            or ``None``.

    Returns:
        The formatted context, or a fixed "nothing found" sentence when
        there is nothing to show.
    """
    if not search_results:
        return "未找到相关信息。"

    def _score(item) -> float:
        # A hit may carry an explicit None (or otherwise non-numeric) score;
        # formatting that with ":.3f" would raise, so fall back to 0.
        raw = item.get("combined_score", item.get("score", 0))
        return raw if isinstance(raw, (int, float)) else 0.0

    # Prefer reranked results; otherwise merge both raw result sets.
    if "reranked_results" in search_results:
        results = search_results["reranked_results"] or {}
    else:
        results = {}
        for key in ("keyword_search", "embedding_search"):
            for category, items in (search_results.get(key) or {}).items():
                results.setdefault(category, []).extend(items or [])

    context_parts = []

    statements = (results.get("statements") or [])[:5]  # top 5
    if statements:
        context_parts.append("相关陈述:")
        for i, stmt in enumerate(statements, 1):
            content = stmt.get("statement", "")
            context_parts.append(f"{i}. {content} (相关度: {_score(stmt):.3f})")

    chunks = (results.get("chunks") or [])[:3]  # top 3
    if chunks:
        context_parts.append("\n相关对话:")
        for i, chunk in enumerate(chunks, 1):
            content = chunk.get("content", "")
            context_parts.append(f"{i}. {content} (相关度: {_score(chunk):.3f})")

    entities = (results.get("entities") or [])[:5]  # top 5
    if entities:
        context_parts.append("\n相关实体:")
        # Coerce missing/None names to "" so join() cannot fail.
        context_parts.append(
            ", ".join(str(ent.get("name") or "") for ent in entities)
        )

    return "\n".join(context_parts) if context_parts else "未找到相关信息。"
def get_search_summary(self, search_results: Dict) -> str:
    """Build a one-line summary of the result counts.

    Args:
        search_results: Mapping returned by the hybrid search; may be empty.

    Returns:
        A fixed "no results" label for empty input; otherwise the counts
        found under ``"combined_summary"`` joined by ``", "``, or a fixed
        "has results" label when no count is available.
    """
    if not search_results:
        return "无结果"

    # (summary key, display label) pairs, in display order.
    labelled_counts = (
        ("total_reranked_results", "重排结果"),
        ("total_keyword_results", "关键词"),
        ("total_embedding_results", "语义"),
    )

    pieces = []
    if "combined_summary" in search_results:
        stats = search_results["combined_summary"]
        pieces = [
            f"{label}: {stats[key]}"
            for key, label in labelled_counts
            if key in stats
        ]

    if pieces:
        return ", ".join(pieces)
    return "有结果"
async def generate_response(self, user_message: str, context: str) -> str:
    """Ask the LLM for an answer grounded in the retrieved context.

    ``self.llm_client.chat`` may be synchronous or asynchronous: an
    awaitable return value is awaited before its content is extracted.
    The reply text is then pulled from whichever response shape is seen:
    an object with ``.content``, an OpenAI-style object with ``.choices``,
    an equivalent dict, or a plain string.

    Args:
        user_message: The user's question.
        context: Text context built from the search results.

    Returns:
        The reply text, a fixed apology when no content can be extracted,
        or an error description when the call raises.
    """
    # Local import so this method does not depend on the module's import list.
    import inspect

    system_prompt = f"""你是一个智能助手,基于知识图谱中的信息回答用户问题。
以下是从知识图谱中检索到的相关信息:
{context}
请基于这些信息回答用户的问题。如果信息不足,请诚实地说明。回答要自然、友好,并且准确。"""
    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]
        response = self.llm_client.chat(
            messages=messages,
        )
        # An async client returns a coroutine here; without awaiting it none
        # of the shape checks below can match and the reply is lost.
        if inspect.isawaitable(response):
            response = await response

        # 1) LangChain AIMessage or similar object with `.content`
        if hasattr(response, 'content'):
            return getattr(response, 'content')
        # 2) OpenAI-style response with `.choices`
        if hasattr(response, 'choices') and response.choices:
            first_choice = response.choices[0]
            # Newer clients expose `.message.content`, some `.content` directly
            if hasattr(first_choice, 'message') and hasattr(first_choice.message, 'content'):
                return first_choice.message.content
            if hasattr(first_choice, 'content'):
                return first_choice.content
        # 3) Dict-like responses
        if isinstance(response, dict):
            if 'content' in response:
                return response['content']
            if 'choices' in response and response['choices']:
                ch = response['choices'][0]
                if isinstance(ch, dict):
                    if 'message' in ch and 'content' in ch['message']:
                        return ch['message']['content']
                    if 'content' in ch:
                        return ch['content']
        # 4) Plain string
        if isinstance(response, str):
            return response
        # Nothing recognisable came back
        return "抱歉,我无法生成回复。"
    except Exception as e:
        return f"生成回复时出错: {str(e)}"
def on_response_ready(self, response: str, metadata: Dict[str, Any]):
    """Show the assistant's reply and return the UI to the idle state.

    Args:
        response: Reply text to display.
        metadata: Timing/search details passed along to ``add_message``.
    """
    self.add_message("助手", response, metadata)

    # Re-enable sending, reset the status text, and refocus the input box.
    button = self.send_button
    button.config(state=tk.NORMAL)
    status = self.status_label
    status.config(text="就绪")
    self.user_input.focus()
def on_error(self, error_message: str):
    """Show an error as a system message and return the UI to the idle state.

    Args:
        error_message: Human-readable description of what went wrong.
    """
    self.add_message("系统", f" {error_message}")

    # Re-enable sending, reset the status text, and refocus the input box.
    button = self.send_button
    button.config(state=tk.NORMAL)
    status = self.status_label
    status.config(text="就绪")
    self.user_input.focus()
def run(self):
    """Start the chatbot by entering the root window's main loop."""
    root = self.root
    root.mainloop()
def main():
    """Script entry point: build the chatbot and run it.

    Any exception raised while constructing or running the chatbot is
    reported on stdout instead of propagating.
    """
    try:
        HybridSearchChatbot().run()
    except Exception as e:
        print(f"启动聊天机器人时出错: {e}")


if __name__ == "__main__":
    main()

View File

@@ -1,408 +1,408 @@
# -*- coding: utf-8 -*-
"""混合搜索策略
# # -*- coding: utf-8 -*-
# """混合搜索策略
结合关键词搜索和语义搜索的混合检索方法。
支持结果重排序和遗忘曲线加权。
"""
# 结合关键词搜索和语义搜索的混合检索方法。
# 支持结果重排序和遗忘曲线加权。
# """
from typing import List, Dict, Any, Optional
import math
from datetime import datetime
from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
# from typing import List, Dict, Any, Optional
# import math
# from datetime import datetime
# from app.core.logging_config import get_memory_logger
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
# from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
# from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
# from app.core.memory.models.variate_config import ForgettingEngineConfig
# from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
logger = get_memory_logger(__name__)
# logger = get_memory_logger(__name__)
class HybridSearchStrategy(SearchStrategy):
"""混合搜索策略
# class HybridSearchStrategy(SearchStrategy):
# """混合搜索策略
结合关键词搜索和语义搜索的优势:
- 关键词搜索:精确匹配,适合已知术语
- 语义搜索:语义理解,适合概念查询
- 混合重排序:综合两种搜索的结果
- 遗忘曲线:根据时间衰减调整相关性
"""
# 结合关键词搜索和语义搜索的优势:
# - 关键词搜索:精确匹配,适合已知术语
# - 语义搜索:语义理解,适合概念查询
# - 混合重排序:综合两种搜索的结果
# - 遗忘曲线:根据时间衰减调整相关性
# """
def __init__(
self,
connector: Optional[Neo4jConnector] = None,
embedder_client: Optional[OpenAIEmbedderClient] = None,
alpha: float = 0.6,
use_forgetting_curve: bool = False,
forgetting_config: Optional[ForgettingEngineConfig] = None
):
"""初始化混合搜索策略
# def __init__(
# self,
# connector: Optional[Neo4jConnector] = None,
# embedder_client: Optional[OpenAIEmbedderClient] = None,
# alpha: float = 0.6,
# use_forgetting_curve: bool = False,
# forgetting_config: Optional[ForgettingEngineConfig] = None
# ):
# """初始化混合搜索策略
Args:
connector: Neo4j连接器
embedder_client: 嵌入模型客户端
alpha: BM25分数权重0.0-1.01-alpha为嵌入分数权重
use_forgetting_curve: 是否使用遗忘曲线
forgetting_config: 遗忘引擎配置
"""
self.connector = connector
self.embedder_client = embedder_client
self.alpha = alpha
self.use_forgetting_curve = use_forgetting_curve
self.forgetting_config = forgetting_config or ForgettingEngineConfig()
self._owns_connector = connector is None
# Args:
# connector: Neo4j连接器
# embedder_client: 嵌入模型客户端
# alpha: BM25分数权重0.0-1.01-alpha为嵌入分数权重
# use_forgetting_curve: 是否使用遗忘曲线
# forgetting_config: 遗忘引擎配置
# """
# self.connector = connector
# self.embedder_client = embedder_client
# self.alpha = alpha
# self.use_forgetting_curve = use_forgetting_curve
# self.forgetting_config = forgetting_config or ForgettingEngineConfig()
# self._owns_connector = connector is None
# 创建子策略
self.keyword_strategy = KeywordSearchStrategy(connector=connector)
self.semantic_strategy = SemanticSearchStrategy(
connector=connector,
embedder_client=embedder_client
)
# # 创建子策略
# self.keyword_strategy = KeywordSearchStrategy(connector=connector)
# self.semantic_strategy = SemanticSearchStrategy(
# connector=connector,
# embedder_client=embedder_client
# )
async def __aenter__(self):
    """Enter the async context.

    When this strategy owns its connector (none was supplied at
    construction), a ``Neo4jConnector`` is created here and shared with
    the keyword and semantic sub-strategies.

    Returns:
        This strategy instance.
    """
    if not self._owns_connector:
        return self

    self.connector = Neo4jConnector()
    for sub_strategy in (self.keyword_strategy, self.semantic_strategy):
        sub_strategy.connector = self.connector
    return self
# async def __aenter__(self):
# """异步上下文管理器入口"""
# if self._owns_connector:
# self.connector = Neo4jConnector()
# self.keyword_strategy.connector = self.connector
# self.semantic_strategy.connector = self.connector
# return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context, closing the connector only if this strategy owns it.

    Args:
        exc_type: Exception type raised inside the context, if any; unused.
        exc_val: Exception instance, if any; unused.
        exc_tb: Traceback, if any; unused.
    """
    if not (self._owns_connector and self.connector):
        return
    await self.connector.close()
# async def __aexit__(self, exc_type, exc_val, exc_tb):
# """异步上下文管理器出口"""
# if self._owns_connector and self.connector:
# await self.connector.close()
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
) -> SearchResult:
"""执行混合搜索
# async def search(
# self,
# query_text: str,
# group_id: Optional[str] = None,
# limit: int = 50,
# include: Optional[List[str]] = None,
# **kwargs
# ) -> SearchResult:
# """执行混合搜索
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数如alpha, use_forgetting_curve
# Args:
# query_text: 查询文本
# group_id: 可选的组ID过滤
# limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# Returns:
# SearchResult: 搜索结果对象
# """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# 从kwargs中获取参数
alpha = kwargs.get("alpha", self.alpha)
use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
# # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha)
# use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
# 获取有效的搜索类别
include_list = self._get_include_list(include)
# # 获取有效的搜索类别
# include_list = self._get_include_list(include)
try:
# 并行执行关键词搜索和语义搜索
keyword_result = await self.keyword_strategy.search(
query_text=query_text,
group_id=group_id,
limit=limit,
include=include_list
)
# try:
# # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search(
# query_text=query_text,
# group_id=group_id,
# limit=limit,
# include=include_list
# )
semantic_result = await self.semantic_strategy.search(
query_text=query_text,
group_id=group_id,
limit=limit,
include=include_list
)
# semantic_result = await self.semantic_strategy.search(
# query_text=query_text,
# group_id=group_id,
# limit=limit,
# include=include_list
# )
# 重排序结果
if use_forgetting:
reranked_results = self._rerank_with_forgetting_curve(
keyword_result=keyword_result,
semantic_result=semantic_result,
alpha=alpha,
limit=limit
)
else:
reranked_results = self._rerank_hybrid_results(
keyword_result=keyword_result,
semantic_result=semantic_result,
alpha=alpha,
limit=limit
)
# # 重排序结果
# if use_forgetting:
# reranked_results = self._rerank_with_forgetting_curve(
# keyword_result=keyword_result,
# semantic_result=semantic_result,
# alpha=alpha,
# limit=limit
# )
# else:
# reranked_results = self._rerank_hybrid_results(
# keyword_result=keyword_result,
# semantic_result=semantic_result,
# alpha=alpha,
# limit=limit
# )
# 创建元数据
metadata = self._create_metadata(
query_text=query_text,
search_type="hybrid",
group_id=group_id,
limit=limit,
include=include_list,
alpha=alpha,
use_forgetting_curve=use_forgetting
)
# # 创建元数据
# metadata = self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# limit=limit,
# include=include_list,
# alpha=alpha,
# use_forgetting_curve=use_forgetting
# )
# 添加结果统计
metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
metadata["total_keyword_results"] = keyword_result.total_results()
metadata["total_semantic_results"] = semantic_result.total_results()
metadata["total_reranked_results"] = reranked_results.total_results()
# # 添加结果统计
# metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
# metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
# metadata["total_keyword_results"] = keyword_result.total_results()
# metadata["total_semantic_results"] = semantic_result.total_results()
# metadata["total_reranked_results"] = reranked_results.total_results()
reranked_results.metadata = metadata
# reranked_results.metadata = metadata
logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
return reranked_results
# logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
# return reranked_results
except Exception as e:
logger.error(f"混合搜索失败: {e}", exc_info=True)
# 返回空结果但包含错误信息
return SearchResult(
metadata=self._create_metadata(
query_text=query_text,
search_type="hybrid",
group_id=group_id,
limit=limit,
error=str(e)
)
)
# except Exception as e:
# logger.error(f"混合搜索失败: {e}", exc_info=True)
# # 返回空结果但包含错误信息
# return SearchResult(
# metadata=self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# limit=limit,
# error=str(e)
# )
# )
def _normalize_scores(
self,
results: List[Dict[str, Any]],
score_field: str = "score"
) -> List[Dict[str, Any]]:
"""使用z-score标准化和sigmoid转换归一化分数
# def _normalize_scores(
# self,
# results: List[Dict[str, Any]],
# score_field: str = "score"
# ) -> List[Dict[str, Any]]:
# """使用z-score标准化和sigmoid转换归一化分数
Args:
results: 结果列表
score_field: 分数字段名
# Args:
# results: 结果列表
# score_field: 分数字段名
Returns:
List[Dict[str, Any]]: 归一化后的结果列表
"""
if not results:
return results
# Returns:
# List[Dict[str, Any]]: 归一化后的结果列表
# """
# if not results:
# return results
# 提取分数
scores = []
for item in results:
if score_field in item:
score = item.get(score_field)
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
scores.append(0.0)
# # 提取分数
# scores = []
# for item in results:
# if score_field in item:
# score = item.get(score_field)
# if score is not None and isinstance(score, (int, float)):
# scores.append(float(score))
# else:
# scores.append(0.0)
if not scores or len(scores) == 1:
# 单个分数或无分数设置为1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
return results
# if not scores or len(scores) == 1:
# # 单个分数或无分数设置为1.0
# for item in results:
# if score_field in item:
# item[f"normalized_{score_field}"] = 1.0
# return results
# 计算均值和标准差
mean_score = sum(scores) / len(scores)
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
std_dev = math.sqrt(variance)
# # 计算均值和标准差
# mean_score = sum(scores) / len(scores)
# variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
# std_dev = math.sqrt(variance)
if std_dev == 0:
# 所有分数相同设置为1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
else:
# z-score标准化 + sigmoid转换
for item in results:
if score_field in item:
score = item[score_field]
if score is None or not isinstance(score, (int, float)):
score = 0.0
z_score = (score - mean_score) / std_dev
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
# if std_dev == 0:
# # 所有分数相同设置为1.0
# for item in results:
# if score_field in item:
# item[f"normalized_{score_field}"] = 1.0
# else:
# # z-score标准化 + sigmoid转换
# for item in results:
# if score_field in item:
# score = item[score_field]
# if score is None or not isinstance(score, (int, float)):
# score = 0.0
# z_score = (score - mean_score) / std_dev
# normalized = 1 / (1 + math.exp(-z_score))
# item[f"normalized_{score_field}"] = normalized
return results
# return results
def _rerank_hybrid_results(
self,
keyword_result: SearchResult,
semantic_result: SearchResult,
alpha: float,
limit: int
) -> SearchResult:
"""重排序混合搜索结果
# def _rerank_hybrid_results(
# self,
# keyword_result: SearchResult,
# semantic_result: SearchResult,
# alpha: float,
# limit: int
# ) -> SearchResult:
# """重排序混合搜索结果
Args:
keyword_result: 关键词搜索结果
semantic_result: 语义搜索结果
alpha: BM25分数权重
limit: 结果限制
# Args:
# keyword_result: 关键词搜索结果
# semantic_result: 语义搜索结果
# alpha: BM25分数权重
# limit: 结果限制
Returns:
SearchResult: 重排序后的结果
"""
reranked_data = {}
# Returns:
# SearchResult: 重排序后的结果
# """
# reranked_data = {}
for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = getattr(keyword_result, category, [])
semantic_items = getattr(semantic_result, category, [])
# for category in ["statements", "chunks", "entities", "summaries"]:
# keyword_items = getattr(keyword_result, category, [])
# semantic_items = getattr(semantic_result, category, [])
# 归一化分数
keyword_items = self._normalize_scores(keyword_items, "score")
semantic_items = self._normalize_scores(semantic_items, "score")
# # 归一化分数
# keyword_items = self._normalize_scores(keyword_items, "score")
# semantic_items = self._normalize_scores(semantic_items, "score")
# 合并结果
combined_items = {}
# # 合并结果
# combined_items = {}
# 添加关键词结果
for item in keyword_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0
# # 添加关键词结果
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0
# 添加或更新语义结果
for item in semantic_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
if item_id in combined_items:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# # 添加或更新语义结果
# for item in semantic_items:
# item_id = item.get("id") or item.get("uuid")
# if item_id:
# if item_id in combined_items:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# 计算组合分数
for item_id, item in combined_items.items():
bm25_score = item.get("bm25_score", 0)
embedding_score = item.get("embedding_score", 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
item["combined_score"] = combined_score
# # 计算组合分数
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
# 排序并限制结果
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
# # 排序并限制结果
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
reranked_data[category] = sorted_items
# reranked_data[category] = sorted_items
return SearchResult(
statements=reranked_data.get("statements", []),
chunks=reranked_data.get("chunks", []),
entities=reranked_data.get("entities", []),
summaries=reranked_data.get("summaries", [])
)
# return SearchResult(
# statements=reranked_data.get("statements", []),
# chunks=reranked_data.get("chunks", []),
# entities=reranked_data.get("entities", []),
# summaries=reranked_data.get("summaries", [])
# )
def _parse_datetime(self, value: Any) -> Optional[datetime]:
    """Convert a value to a ``datetime``, or return ``None`` if it cannot be parsed.

    Args:
        value: A ``datetime`` (returned unchanged), an ISO-8601 string, or
            anything else.

    Returns:
        The parsed ``datetime``; ``None`` for ``None``, non-string values,
        blank strings, and strings that are not valid ISO-8601.
    """
    if isinstance(value, datetime):
        return value
    if not isinstance(value, str):
        return None

    text = value.strip()
    if not text:
        return None

    # ``datetime.fromisoformat`` rejects a trailing "Z" (UTC designator)
    # before Python 3.11; rewrite it as an explicit offset so such
    # timestamps are not silently dropped on older interpreters.
    if text.endswith(("Z", "z")):
        text = f"{text[:-1]}+00:00"

    try:
        return datetime.fromisoformat(text)
    except ValueError:
        return None
# def _parse_datetime(self, value: Any) -> Optional[datetime]:
# """解析日期时间字符串"""
# if value is None:
# return None
# if isinstance(value, datetime):
# return value
# if isinstance(value, str):
# s = value.strip()
# if not s:
# return None
# try:
# return datetime.fromisoformat(s)
# except Exception:
# return None
# return None
def _rerank_with_forgetting_curve(
self,
keyword_result: SearchResult,
semantic_result: SearchResult,
alpha: float,
limit: int
) -> SearchResult:
"""使用遗忘曲线重排序混合搜索结果
# def _rerank_with_forgetting_curve(
# self,
# keyword_result: SearchResult,
# semantic_result: SearchResult,
# alpha: float,
# limit: int
# ) -> SearchResult:
# """使用遗忘曲线重排序混合搜索结果
Args:
keyword_result: 关键词搜索结果
semantic_result: 语义搜索结果
alpha: BM25分数权重
limit: 结果限制
# Args:
# keyword_result: 关键词搜索结果
# semantic_result: 语义搜索结果
# alpha: BM25分数权重
# limit: 结果限制
Returns:
SearchResult: 重排序后的结果
"""
engine = ForgettingEngine(self.forgetting_config)
now_dt = datetime.now()
# Returns:
# SearchResult: 重排序后的结果
# """
# engine = ForgettingEngine(self.forgetting_config)
# now_dt = datetime.now()
reranked_data = {}
# reranked_data = {}
for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = getattr(keyword_result, category, [])
semantic_items = getattr(semantic_result, category, [])
# for category in ["statements", "chunks", "entities", "summaries"]:
# keyword_items = getattr(keyword_result, category, [])
# semantic_items = getattr(semantic_result, category, [])
# 归一化分数
keyword_items = self._normalize_scores(keyword_items, "score")
semantic_items = self._normalize_scores(semantic_items, "score")
# # 归一化分数
# keyword_items = self._normalize_scores(keyword_items, "score")
# semantic_items = self._normalize_scores(semantic_items, "score")
# 合并结果
combined_items = {}
# # 合并结果
# combined_items = {}
for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
for item in src_items:
item_id = item.get("id") or item.get("uuid")
if not item_id:
continue
# for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
# for item in src_items:
# item_id = item.get("id") or item.get("uuid")
# if not item_id:
# continue
if item_id not in combined_items:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = 0
# if item_id not in combined_items:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
if is_embedding:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# 计算分数并应用遗忘权重
for item_id, item in combined_items.items():
bm25_score = float(item.get("bm25_score", 0) or 0)
embedding_score = float(item.get("embedding_score", 0) or 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# # 计算分数并应用遗忘权重
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# 计算时间衰减
dt = self._parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# # 计算时间衰减
# dt = self._parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
memory_strength = 1.0 # 默认强度
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# memory_strength = 1.0 # 默认强度
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days,
# memory_strength=memory_strength
# )
final_score = combined_score * forgetting_weight
item["combined_score"] = final_score
item["forgetting_weight"] = forgetting_weight
item["time_elapsed_days"] = time_elapsed_days
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
# item["forgetting_weight"] = forgetting_weight
# item["time_elapsed_days"] = time_elapsed_days
# 排序并限制结果
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
# # 排序并限制结果
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
reranked_data[category] = sorted_items
# reranked_data[category] = sorted_items
return SearchResult(
statements=reranked_data.get("statements", []),
chunks=reranked_data.get("chunks", []),
entities=reranked_data.get("entities", []),
summaries=reranked_data.get("summaries", [])
)
# return SearchResult(
# statements=reranked_data.get("statements", []),
# chunks=reranked_data.get("chunks", []),
# entities=reranked_data.get("entities", []),
# summaries=reranked_data.get("summaries", [])
# )

View File

@@ -1,445 +0,0 @@
# Memory 模块工具函数文档
本目录包含 Memory 模块使用的所有工具函数,统一管理以提高代码可维护性和可复用性。
## 目录结构
```
app/core/memory/utils/
├── __init__.py # 包初始化文件,导出所有公共接口
├── README.md # 本文档
├── config/ # 配置管理模块
│ ├── __init__.py # 配置模块初始化
│ ├── config_utils.py # 配置管理工具
│ ├── definitions.py # 全局定义和常量
│ ├── overrides.py # 运行时配置覆写
│ ├── get_data.py # 数据获取工具
│ ├── litellm_config.py # LiteLLM 配置和监控
│ └── config_optimization.py # 配置优化工具
├── log/ # 日志管理模块
│ ├── __init__.py # 日志模块初始化
│ ├── logging_utils.py # 日志工具
│ └── audit_logger.py # 审计日志
├── prompt/ # 提示词管理模块
│ ├── __init__.py # 提示词模块初始化
│ ├── prompt_utils.py # 提示词渲染工具
│ ├── template_render.py # 模板渲染工具
│ └── prompts/ # Jinja2 提示词模板目录
│ ├── entity_dedup.jinja2 # 实体去重提示词
│ ├── extract_statement.jinja2 # 陈述句提取提示词
│ ├── extract_temporal.jinja2 # 时间信息提取提示词
│ ├── extract_triplet.jinja2 # 三元组提取提示词
│ ├── memory_summary.jinja2 # 记忆摘要提示词
│ ├── evaluate.jinja2 # 评估提示词
│ ├── reflexion.jinja2 # 反思提示词
│ ├── system.jinja2 # 系统提示词
│ └── user.jinja2 # 用户提示词
├── llm/ # LLM 工具模块
│ ├── __init__.py # LLM 模块初始化
│ └── llm_utils.py # LLM 客户端工具
├── data/ # 数据处理模块
│ ├── __init__.py # 数据模块初始化
│ ├── text_utils.py # 文本处理工具
│ ├── time_utils.py # 时间处理工具
│ └── ontology.py # 本体定义(谓语、标签等)
├── paths/ # 路径管理模块
│ ├── __init__.py # 路径模块初始化
│ └── output_paths.py # 输出路径管理
├── visualization/ # 可视化模块
│ ├── __init__.py # 可视化模块初始化
│ └── forgetting_visualizer.py # 遗忘曲线可视化
└── self_reflexion_utils/ # 自我反思工具模块
├── __init__.py # 反思模块初始化
├── evaluate.py # 冲突评估
├── reflexion.py # 反思处理
└── self_reflexion.py # 自我反思主逻辑
```
## 模块分类
### 1. 配置管理(config/)
配置管理模块包含所有与配置相关的工具函数和定义。
#### config_utils.py
提供配置加载和管理功能:
- `get_model_config(model_id)` - 获取 LLM 模型配置
- `get_embedder_config(embedding_id)` - 获取嵌入模型配置
- `get_neo4j_config()` - 获取 Neo4j 数据库配置
- `get_chunker_config(chunker_strategy)` - 获取分块策略配置
- `get_pipeline_config()` - 获取流水线配置
- `get_pruning_config()` - 获取语义剪枝配置
- `get_picture_config()` - 获取图片模型配置
- `get_voice_config()` - 获取语音模型配置
#### definitions.py
全局定义和常量:
- `CONFIG` - 基础配置(从 config.json 加载)
- `RUNTIME_CONFIG` - 运行时配置(从 runtime.json 或数据库加载)
- `PROJECT_ROOT` - 项目根目录路径
- 各种选择配置常量(LLM、嵌入模型、分块策略等)
- `reload_configuration_from_database(config_id)` - 动态重新加载配置
#### overrides.py
运行时配置覆写:
- `load_unified_config(project_root)` - 加载统一配置
#### get_data.py
数据获取工具:
- `get_data(host_id)` - 从 SQL 数据库获取数据
#### litellm_config.py
LiteLLM 配置和监控:
- `LiteLLMConfig` - LiteLLM 配置类
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
- `get_usage_summary()` - 获取使用统计摘要
- `print_usage_summary()` - 打印使用统计
- `get_instant_qps(module)` - 获取即时 QPS 数据
- `print_instant_qps(module)` - 打印即时 QPS 信息
#### config_optimization.py
配置优化工具:
- 配置参数优化相关功能
### 3. LLM 工具(llm/)
LLM 工具模块包含所有与 LLM 客户端相关的工具函数。
#### llm_utils.py
LLM 客户端工具:
- `get_llm_client(llm_id)` - 获取 LLM 客户端实例
- `get_reranker_client(rerank_id)` - 获取重排序客户端实例
- `handle_response(response)` - 处理 LLM 响应
#### litellm_config.py
LiteLLM 配置和监控:
- `LiteLLMConfig` - LiteLLM 配置类
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
- `get_usage_summary()` - 获取使用统计摘要
- `print_usage_summary()` - 打印使用统计
- `get_instant_qps(module)` - 获取即时 QPS 数据
- `print_instant_qps(module)` - 打印即时 QPS 信息
### 4. 提示词管理(prompt/)
提示词管理模块包含所有提示词渲染和模板管理相关的工具函数。
#### prompt_utils.py
提示词渲染工具(使用 Jinja2 模板):
- `get_prompts(message)` - 获取系统和用户提示词
- `render_statement_extraction_prompt(...)` - 渲染陈述句提取提示词
- `render_temporal_extraction_prompt(...)` - 渲染时间信息提取提示词
- `render_entity_dedup_prompt(...)` - 渲染实体去重提示词
- `render_triplet_extraction_prompt(...)` - 渲染三元组提取提示词
- `render_memory_summary_prompt(...)` - 渲染记忆摘要提示词
- `prompt_env` - Jinja2 环境对象
#### template_render.py
模板渲染工具(用于评估和反思):
- `render_evaluate_prompt(evaluate_data, schema)` - 渲染评估提示词
- `render_reflexion_prompt(data, schema)` - 渲染反思提示词
#### prompts/
Jinja2 模板文件目录,包含所有提示词模板
### 5. 数据处理(data/)
数据处理模块包含所有数据处理相关的工具函数。
#### text_utils.py
文本处理工具:
- `escape_lucene_query(query)` - 转义 Lucene 查询特殊字符
- `extract_plain_query(query_input)` - 从各种输入格式提取纯文本查询
#### time_utils.py
时间处理工具:
- `validate_date_format(date_str)` - 验证日期格式(YYYY-MM-DD)
- `normalize_date(date_str)` - 标准化日期格式
- `normalize_date_safe(date_str, default)` - 安全的日期标准化(带默认值)
- `preprocess_date_string(date_str)` - 预处理日期字符串
#### ontology.py
本体定义:
- `PREDICATE_DEFINITIONS` - 谓语定义字典
- `LABEL_DEFINITIONS` - 标签定义字典
- `Predicate` - 谓语枚举
- `StatementType` - 陈述句类型枚举
- `TemporalInfo` - 时间信息枚举
- `RelevenceInfo` - 相关性信息枚举
### 2. 日志管理(log/)
日志管理模块包含所有与日志记录相关的工具函数。
#### logging_utils.py
日志工具:
- `log_prompt_rendering(role, content)` - 记录提示词渲染
- `log_template_rendering(template_name, context)` - 记录模板渲染
- `log_time(operation, duration)` - 记录操作耗时
- `prompt_logger` - 提示词日志记录器
#### audit_logger.py
审计日志:
- `audit_logger` - 审计日志记录器
- 记录系统关键操作和安全事件
### 6. 自我反思工具(self_reflexion_utils/)
自我反思工具模块包含记忆冲突检测和反思处理功能。
#### evaluate.py
冲突评估:
- `conflict(evaluate_data, schema)` - 评估记忆冲突
#### reflexion.py
反思处理:
- `reflexion(data, schema)` - 执行反思处理
#### self_reflexion.py
自我反思主逻辑:
- `self_reflexion(...)` - 自我反思主函数
### 7. 数据模型
#### json_schema.py
JSON Schema 数据模型:
- `BaseDataSchema` - 基础数据模型
- `ConflictResultSchema` - 冲突结果模型
- `ConflictSchema` - 冲突模型
- `ReflexionSchema` - 反思模型
- `ResolvedSchema` - 解决方案模型
- `ReflexionResultSchema` - 反思结果模型
#### messages.py
API 消息模型:
- `ConfigKey` - 配置键模型
- `ChunkerStrategy` - 分块策略枚举
- `ConfigParams` - 配置参数模型
- `ConfigParamsCreate` - 创建配置参数模型
- `ConfigUpdate` - 更新配置模型
- `ConfigUpdateExtracted` - 更新萃取引擎配置模型
- `ConfigUpdateForget` - 更新遗忘引擎配置模型
- `ConfigPilotRun` - 试运行配置模型
- `ConfigFilter` - 配置过滤模型
- `ApiResponse` - API 响应模型
- `ok(msg, data)` - 成功响应构造函数
- `fail(msg, error_code, data)` - 失败响应构造函数
### 8. 可视化(visualization/)
可视化模块包含所有可视化相关的工具函数。
#### forgetting_visualizer.py
遗忘曲线可视化:
- `export_memory_curve_numpy(...)` - 导出记忆曲线为 NumPy 数组
- `export_memory_curves_multiple_strengths(...)` - 导出多个强度的记忆曲线
- `export_parameter_sweep_numpy(...)` - 导出参数扫描结果
- `visualize_forgetting_curve(...)` - 可视化遗忘曲线
- `plot_3d_forgetting_surface(...)` - 绘制 3D 遗忘曲线表面
- `create_comparison_visualization(...)` - 创建对比可视化
- `save_memory_curves_to_file(...)` - 保存记忆曲线到文件
### 9. 路径管理(paths/)
路径管理模块包含所有路径管理相关的工具函数。
#### output_paths.py
输出路径管理:
- `get_output_dir()` - 获取输出目录
- `get_output_path(filename)` - 获取输出文件路径
## 使用示例
### 配置管理
```python
from app.core.memory.utils.config import get_model_config, get_pipeline_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
# 获取模型配置
model_config = get_model_config("model_id_123")
# 获取流水线配置
pipeline_config = get_pipeline_config()
# 使用全局常量
llm_id = SELECTED_LLM_ID
```
### 日志管理
```python
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger
# 记录提示词渲染
log_prompt_rendering('user', 'Hello, world!')
# 记录操作耗时
log_time('extraction', 1.23)
# 使用审计日志
audit_logger.info('User action performed')
```
### LLM 工具
```python
from app.core.memory.utils.llm import get_llm_client
# 获取 LLM 客户端
llm_client = get_llm_client("llm_id_456")
# 调用 LLM
response = await llm_client.chat([
{"role": "user", "content": "Hello"}
])
```
### 提示词渲染
```python
from app.core.memory.utils.prompt import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
# 渲染陈述句提取提示词
prompt = await render_statement_extraction_prompt(
chunk_content="对话内容...",
definitions=LABEL_DEFINITIONS,
json_schema=schema,
granularity=2
)
```
### 数据处理
```python
from app.core.memory.utils.data.time_utils import normalize_date
from app.core.memory.utils.data.text_utils import escape_lucene_query
# 标准化日期
normalized = normalize_date("2025/10/28") # 返回 "2025-10-28"
# 转义 Lucene 查询
escaped = escape_lucene_query("user:admin AND status:active")
```
### 运行时配置覆写
```python
from app.core.memory.utils import apply_runtime_overrides_with_config_id
# 使用指定 config_id 覆写配置
runtime_cfg = {"selections": {}}
updated_cfg = apply_runtime_overrides_with_config_id(
project_root="/path/to/project",
runtime_cfg=runtime_cfg,
config_id="config_123"
)
```
## 迁移说明
### 从旧路径迁移
如果你的代码使用了旧的导入路径,请按以下方式更新:
**旧路径（2024年11月之前）：**
```python
from app.core.memory.src.utils.config_utils import get_model_config
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
from app.core.memory.src.data_config_api.utils.messages import ok, fail
```
**中间路径（2024年11月）：**
```python
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.schemas.memory_storage_schema import ok, fail
```
**新路径（2024年11月27日之后）：**
```python
# 配置相关
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import get_model_config # 简化导入
# 日志相关
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.log import log_prompt_rendering # 简化导入
# 其他工具
from app.core.memory.utils import prompt_utils
from app.schemas.memory_storage_schema import ok, fail
```
### 目录结构重组（2024年11月27日）
utils 目录已按功能进行了完整的重组:
**重组前的结构:**
- 所有文件都在 `app/core/memory/utils/` 根目录下
**重组后的结构:**
- `config/` - 配置管理相关文件
- `log/` - 日志管理相关文件
- `prompt/` - 提示词管理相关文件
- `llm/` - LLM 工具相关文件
- `data/` - 数据处理相关文件
- `paths/` - 路径管理相关文件
- `visualization/` - 可视化相关文件
- `self_reflexion_utils/` - 自我反思工具(已存在)
**导入路径变化:**
```python
# 旧导入方式
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt
# 新导入方式
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
# 或使用简化导入
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.log import log_prompt_rendering
from app.core.memory.utils.prompt import render_statement_extraction_prompt
```
## 维护指南
### 添加新工具函数
1. 在相应的模块文件中添加函数
2. 在 `__init__.py` 中导出函数
3. 在本 README 中添加文档
4. 编写单元测试
### 删除旧工具函数
1. 确认没有代码使用该函数
2. 从模块文件中删除函数
3. 从 `__init__.py` 中删除导出
4. 更新本 README
### 重构工具函数
1. 保持向后兼容性(使用别名或包装器)
2. 更新所有使用该函数的代码
3. 更新文档和测试
4. 在适当时机删除旧版本
## 注意事项
1. **向后兼容性**:所有工具函数应保持向后兼容,避免破坏现有代码
2. **文档完整性**:每个函数都应有清晰的文档字符串
3. **类型注解**:使用类型注解提高代码可读性
4. **错误处理**:工具函数应有适当的错误处理
5. **测试覆盖**:所有工具函数都应有单元测试
## 相关文档
- [Memory 模块架构设计](../.kiro/specs/memory-refactoring/design.md)
- [Memory 模块需求文档](../.kiro/specs/memory-refactoring/requirements.md)
- [Memory 模块任务列表](../.kiro/specs/memory-refactoring/tasks.md)

View File

@@ -6,33 +6,27 @@
# 从子模块导出常用函数和常量,保持向后兼容
from .config_utils import (
get_model_config,
get_embedder_config,
get_neo4j_config,
get_chunker_config,
get_embedder_config,
get_model_config,
get_neo4j_config,
get_picture_config,
get_pipeline_config,
get_pruning_config,
get_picture_config,
get_voice_config,
)
from .definitions import (
CONFIG,
RUNTIME_CONFIG,
PROJECT_ROOT,
SELECTED_LLM_ID,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_RERANK_ID,
SELECTED_LLM_PICTURE_NAME,
SELECTED_LLM_VOICE_NAME,
REFLEXION_ENABLED,
REFLEXION_ITERATION_PERIOD,
REFLEXION_RANGE,
REFLEXION_BASELINE,
reload_configuration_from_database,
)
from .overrides import load_unified_config
# DEPRECATED: Global configuration variables removed
# Use MemoryConfig objects with dependency injection instead
# from .definitions import (
# CONFIG, # DEPRECATED - empty dict for backward compatibility
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
# PROJECT_ROOT, # Still needed for file paths
# reload_configuration_from_database, # DEPRECATED - returns False
# )
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
from .get_data import get_data
# litellm_config 需要时动态导入,避免循环依赖
# from .litellm_config import (
# LiteLLMConfig,
@@ -53,23 +47,11 @@ __all__ = [
"get_pruning_config",
"get_picture_config",
"get_voice_config",
# definitions
"CONFIG",
"RUNTIME_CONFIG",
"PROJECT_ROOT",
"SELECTED_LLM_ID",
"SELECTED_EMBEDDING_ID",
"SELECTED_GROUP_ID",
"SELECTED_RERANK_ID",
"SELECTED_LLM_PICTURE_NAME",
"SELECTED_LLM_VOICE_NAME",
"REFLEXION_ENABLED",
"REFLEXION_ITERATION_PERIOD",
"REFLEXION_RANGE",
"REFLEXION_BASELINE",
"reload_configuration_from_database",
# overrides
"load_unified_config",
# definitions (DEPRECATED - use MemoryConfig objects instead)
# "CONFIG", # DEPRECATED
# "RUNTIME_CONFIG", # DEPRECATED
# "PROJECT_ROOT",
# "reload_configuration_from_database", # DEPRECATED
# get_data
"get_data",
# litellm_config - 需要时从 .litellm_config 直接导入

View File

@@ -1,22 +1,18 @@
import uuid
import json
from typing import Optional
from sqlalchemy.orm import Session
from fastapi.exceptions import HTTPException
from fastapi import status
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
DedupConfig,
StatementExtractionConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.definitions import CONFIG
from app.db import get_db
from app.models.models_model import ModelConfig, ModelApiKey
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService
from fastapi import status
from fastapi.exceptions import HTTPException
from sqlalchemy.orm import Session
def get_model_config(model_id: str, db: Session | None = None) -> dict:
if db is None:
db_gen = get_db() # get_db 通常是一个生成器
@@ -110,6 +106,13 @@ def get_chunker_config(chunker_strategy: str) -> dict:
# 2) Provide sane defaults for newer strategies
default_configs = {
"RecursiveChunker": {
"chunker_strategy": "RecursiveChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"min_characters_per_chunk": 50
},
"LLMChunker": {
"chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3",
@@ -147,94 +150,74 @@ def get_chunker_config(chunker_strategy: str) -> dict:
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
)
#TODO: Fix this
def get_pipeline_config() -> ExtractionPipelineConfig:
"""Build ExtractionPipelineConfig using only runtime.json values.
def get_pipeline_config(
config_id: int,
db: Session | None = None,
) -> ExtractionPipelineConfig:
"""Build ExtractionPipelineConfig from database.
Behavior:
- Read `deduplication` section from runtime.json if present.
- Read `statement_extraction` section from runtime.json if present.
- Read `forgetting_engine` section from runtime.json if present.
- If absent, check legacy top-level `enable_llm_dedup` key.
- Do NOT fall back to environment variables.
- Unspecified fields use model defaults defined in DedupConfig.
Args:
config_id: Database configuration ID (required). Loads pipeline
settings from the DataConfig table.
db: Optional database session. If not provided, a new session
will be created.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings loaded from database.
Raises:
ValueError: If config_id not found in database.
"""
dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {}
stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}
from app.repositories.data_config_repository import DataConfigRepository
# Assemble kwargs from runtime.json only
kwargs = {}
# LLM switch: prefer new key, then legacy top-level, default False
if "enable_llm_dedup_blockwise" in dedup_rc:
kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise"))
else:
# Legacy top-level fallback inside runtime.json only
legacy = RUNTIME_CONFIG.get("enable_llm_dedup")
if legacy is not None:
kwargs["enable_llm_dedup_blockwise"] = bool(legacy)
else:
kwargs["enable_llm_dedup_blockwise"] = False # default reserve
# Disambiguation switch: only from runtime.json deduplication section
if "enable_llm_disambiguation" in dedup_rc:
kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation"))
# Load from database
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
# Optional LLM fallback gating
if "enable_llm_fallback_only_on_borderline" in dedup_rc:
kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline"))
# Build DedupConfig from database
dedup_kwargs = {
"enable_llm_dedup_blockwise": bool(db_config.enable_llm_dedup_blockwise) if db_config.enable_llm_dedup_blockwise is not None else False,
"enable_llm_disambiguation": bool(db_config.enable_llm_disambiguation) if db_config.enable_llm_disambiguation is not None else False,
}
# Fuzzy thresholds
if db_config.t_name_strict is not None:
dedup_kwargs["fuzzy_name_threshold_strict"] = db_config.t_name_strict
if db_config.t_type_strict is not None:
dedup_kwargs["fuzzy_type_threshold_strict"] = db_config.t_type_strict
if db_config.t_overall is not None:
dedup_kwargs["fuzzy_overall_threshold"] = db_config.t_overall
# Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults
for key in (
"fuzzy_name_threshold_strict",
"fuzzy_type_threshold_strict",
"fuzzy_overall_threshold",
"fuzzy_unknown_type_name_threshold",
"fuzzy_unknown_type_type_threshold",
):
if key in dedup_rc:
kwargs[key] = dedup_rc[key]
dedup_config = DedupConfig(**dedup_kwargs)
# Optional weights and bonuses for overall scoring
for key in (
"name_weight",
"desc_weight",
"type_weight",
"context_bonus",
"llm_fallback_floor",
"llm_fallback_ceiling",
):
if key in dedup_rc:
kwargs[key] = dedup_rc[key]
# Optional LLM iterative dedup parameters
for key in (
"llm_block_size",
"llm_block_concurrency",
"llm_pair_concurrency",
"llm_max_rounds",
):
if key in dedup_rc:
kwargs[key] = dedup_rc[key]
dedup_config = DedupConfig(**kwargs)
# Build StatementExtractionConfig from runtime.json
# Build StatementExtractionConfig from database
stmt_kwargs = {}
for key in (
"statement_granularity",
"temperature",
"include_dialogue_context",
"max_dialogue_context_chars",
):
if key in stmt_rc:
stmt_kwargs[key] = stmt_rc[key]
if db_config.statement_granularity is not None:
stmt_kwargs["statement_granularity"] = db_config.statement_granularity
if db_config.include_dialogue_context is not None:
stmt_kwargs["include_dialogue_context"] = bool(db_config.include_dialogue_context)
if db_config.max_context is not None:
stmt_kwargs["max_dialogue_context_chars"] = db_config.max_context
stmt_config = StatementExtractionConfig(**stmt_kwargs)
# Build ForgettingEngineConfig from runtime.json
# Build ForgettingEngineConfig from database
forget_kwargs = {}
for key in ("offset", "lambda_time", "lambda_mem"):
if key in forget_rc:
forget_kwargs[key] = forget_rc[key]
if db_config.offset is not None:
forget_kwargs["offset"] = db_config.offset
if db_config.lambda_time is not None:
forget_kwargs["lambda_time"] = db_config.lambda_time
if db_config.lambda_mem is not None:
forget_kwargs["lambda_mem"] = db_config.lambda_mem
forget_config = ForgettingEngineConfig(**forget_kwargs)
return ExtractionPipelineConfig(
@@ -244,24 +227,37 @@ def get_pipeline_config() -> ExtractionPipelineConfig:
)
def get_pruning_config() -> dict:
"""Retrieve semantic pruning config from runtime.json.
def get_pruning_config(
config_id: int,
db: Session | None = None,
) -> dict:
"""Retrieve semantic pruning config from database.
Returns a dict suitable for PruningConfig.model_validate.
Args:
config_id: Database configuration ID (required).
db: Optional database session.
Structure in runtime.json:
{
"pruning": {
"enabled": true,
"scene": "education" | "online_service" | "outbound",
"threshold": 0.5
}
}
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str ("education" | "online_service" | "outbound")
- pruning_threshold: float (0-0.9)
Raises:
ValueError: If config_id not found in database.
"""
pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {}
from app.repositories.data_config_repository import DataConfigRepository
if db is None:
db_gen = get_db()
db = next(db_gen)
db_config = DataConfigRepository.get_by_id(db, config_id)
if db_config is None:
raise ValueError(f"Configuration {config_id} not found in database")
return {
"pruning_switch": bool(pruning_rc.get("enabled", False)),
"pruning_scene": pruning_rc.get("scene", "education"),
"pruning_threshold": float(pruning_rc.get("threshold", 0.5)),
"pruning_switch": bool(db_config.pruning_enabled) if db_config.pruning_enabled is not None else False,
"pruning_scene": db_config.pruning_scene or "education",
"pruning_threshold": float(db_config.pruning_threshold) if db_config.pruning_threshold is not None else 0.5,
}

View File

@@ -1,18 +1,26 @@
"""
配置加载模块 - 三阶段架构(已迁移到统一配置管理)
配置加载模块 - DEPRECATED
本模块现在使用全局配置管理系统 (app/core/config.py)
来加载和管理配置,同时保持向后兼容性。
⚠️ DEPRECATION NOTICE ⚠️
This module is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection.
阶段 1: 从 runtime.json 加载配置(路径 A
阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)
Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED
阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
"""
import os
import json
import os
import threading
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
from typing import Any, Dict, Optional
#TODO: Fix this
try:
from dotenv import load_dotenv
@@ -35,21 +43,12 @@ except ImportError:
# os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 全局配置锁 - 用于线程安全
_config_lock = threading.RLock()
# DEPRECATED: Global configuration lock removed
# Use MemoryConfig objects with dependency injection instead
# 加载基础配置config.json- 使用全局配置系统
if USE_UNIFIED_CONFIG:
CONFIG = settings.load_memory_config()
else:
# Fallback to legacy loading
config_path = os.path.join(PROJECT_ROOT, "config.json")
try:
with open(config_path, "r") as f:
CONFIG = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
print("Warning: config.json not found or is malformed. Using default settings.")
CONFIG = {}
# DEPRECATED: Legacy config.json loading removed
# Use MemoryConfig objects with dependency injection instead
CONFIG = {}
DEFAULT_VALUES = {
"llm_name": "openai/qwen-plus",
@@ -68,35 +67,31 @@ DEFAULT_VALUES = {
"reflexion_baseline": "TIME",
}
# DEPRECATED: Legacy global variables for backward compatibility only
# These will be removed in a future version
# Use MemoryConfig objects with dependency injection instead
LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# 阶段 1: 从 runtime.json 加载配置(路径 A
def _load_from_runtime_json() -> Dict[str, Any]:
"""
从 runtime.json 文件加载配置(通过统一配置加载器)
DEPRECATED: Legacy runtime.json loading
使用 overrides.py 的统一配置加载器,按优先级加载:
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id
2. 环境变量配置
3. runtime.json 默认配置
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
Returns:
Dict[str, Any]: 运行时配置字典
Dict[str, Any]: Empty configuration (legacy support only)
"""
try:
# 使用 overrides.py 的统一配置加载器
from app.core.memory.utils.config.overrides import load_unified_config
runtime_cfg = load_unified_config(PROJECT_ROOT)
return runtime_cfg
except Exception as e:
# Fallback: 直接读取 runtime.json
runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
try:
with open(runtime_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e2:
pass # print(f"[definitions] ❌ 无法加载 runtime.json: {e2},使用空配置")
return {"selections": {}}
import warnings
warnings.warn(
"Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return {"selections": {}}
# 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
@@ -104,207 +99,116 @@ def _load_from_runtime_json() -> Dict[str, Any]:
# 保留此函数仅为向后兼容
def _load_from_database() -> Optional[Dict[str, Any]]:
"""
从数据库加载配置(基于 dbrun.json 中的 config_id
DEPRECATED: Legacy database configuration loading
注意:此函数已被统一配置加载器替代,现在直接调用 _load_from_runtime_json
即可获得包含数据库配置的完整配置。
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
Returns:
Optional[Dict[str, Any]]: 配置字典
Optional[Dict[str, Any]]: None (deprecated functionality)
"""
try:
# 直接使用统一配置加载器
return _load_from_runtime_json()
except Exception:
return None
import warnings
warnings.warn(
"Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return None
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
"""
将运行时配置暴露为全局常量供项目使用
这是路径 Aruntime.json和路径 B数据库的汇合点
无论配置来自哪里,都通过这个函数统一暴露为常量。
DEPRECATED: 将运行时配置暴露为全局常量供项目使用
⚠️ This function is deprecated and will be removed in a future version.
Global configuration variables have been eliminated in favor of dependency injection.
Use the new MemoryConfig system instead:
- app.core.memory_config.config.MemoryConfig for configuration objects
- Pass configuration objects as parameters instead of using global variables
Args:
runtime_cfg: 运行时配置字典
"""
global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE
RUNTIME_CONFIG = runtime_cfg
# 可观测性配置
LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False)
AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False)
# 日志配置
LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])
# 选择配置
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# 基础模型选择
SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"])
SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"])
SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])
# 分组和用户配置
SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"])
SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"])
SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"])
SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None)
# 专用 LLM 配置
SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])
# 模型 ID 配置
SELECTED_LLM_ID = SELECTIONS.get("llm_id", None)
SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None)
SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None)
import warnings
warnings.warn(
"Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# 反思配置
REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False)
REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"])
REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"])
# Keep minimal global state for backward compatibility only
# These will be removed in a future version
global RUNTIME_CONFIG, SELECTIONS
RUNTIME_CONFIG = runtime_cfg
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# All other global variables have been removed
# Use MemoryConfig objects instead
# 初始化:使用统一配置加载器
def _initialize_configuration() -> None:
"""
初始化配置:使用统一配置加载器
DEPRECATED: Legacy configuration initialization
配置加载优先级(由 overrides.py 统一处理):
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id
2. 环境变量配置(.env
3. runtime.json 默认配置
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
"""
try:
# 使用统一配置加载器(已包含所有优先级处理)
runtime_config = _load_from_runtime_json()
# 暴露为全局常量
_expose_runtime_constants(runtime_config)
except Exception as e:
pass # print(f"[definitions] × 配置初始化失败: {e}")
# 使用空配置
_expose_runtime_constants({"selections": {}})
import warnings
warnings.warn(
"Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
# Initialize with empty configuration for backward compatibility
_expose_runtime_constants({"selections": {}})
# 模块加载时自动初始化配置
_initialize_configuration()
# DEPRECATED: Global variables removed
# These variables have been eliminated in favor of dependency injection
# Use MemoryConfig objects instead of accessing global variables
# 公共 API动态重新加载配置
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
"""
动态重新加载配置(从数据库)- 使用统一配置加载器
用于运行时切换配置,例如前端传入新的 config_id 时调用。
注意:此函数仅在内存中覆写配置,不会修改 runtime.json 文件。
DEPRECATED: Legacy configuration reloading
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args:
config_id: 配置 ID整数或字符串会自动转换
force_reload: 保留参数以保持向后兼容(已移除缓存逻辑)
config_id: Configuration ID (deprecated)
force_reload: Force reload flag (deprecated)
Returns:
bool: 是否成功重新加载配置
bool: Always returns False (deprecated functionality)
"""
import logging
import warnings
logger = logging.getLogger(__name__)
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
warnings.warn(
"reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
with _config_lock:
try:
from app.core.memory.utils.config.overrides import load_unified_config
except Exception as e:
logger.error(f"[definitions] 导入统一配置加载器失败: {e}")
# 记录配置加载失败
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"error": f"Import failed: {str(e)}"}
)
return False
try:
logger.info(f"[definitions] 开始重新加载配置config_id={config_id}")
# 使用统一配置加载器(指定 config_id
updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)
# 检查是否成功加载
if not updated_cfg or not updated_cfg.get('selections'):
logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置")
# 记录配置加载失败
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"reason": "config not found in database"}
)
return False
# 重新暴露常量
_expose_runtime_constants(updated_cfg)
logger.info("[definitions] 配置重新加载成功,已暴露常量")
logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")
# 记录成功的配置加载
if audit_logger:
selections = updated_cfg.get('selections', {})
audit_logger.log_config_load(
config_id=config_id,
user_id=selections.get('user_id', None),
group_id=selections.get('group_id', None),
success=True,
details={
"llm_id": selections.get('llm_id'),
"embedding_id": selections.get('embedding_id'),
"chunker_strategy": selections.get('chunker_strategy')
}
)
return True
except Exception as e:
logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True)
# 记录配置加载异常
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"error": str(e)}
)
return False
logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.")
return False
@@ -312,49 +216,54 @@ def reload_configuration_from_database(config_id: int | str, force_reload: bool
def get_current_config_id() -> Optional[str]:
"""
获取当前使用的 config_id
DEPRECATED: Legacy config ID retrieval
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
Returns:
Optional[str]: 当前的 config_id如果未设置则返回 None
Optional[str]: None (deprecated functionality)
"""
return SELECTIONS.get("config_id", None)
import warnings
warnings.warn(
"get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
return None
def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
def ensure_fresh_config(config_id = None) -> bool:
"""
确保使用最新的配置(每次写入操作前调用)
DEPRECATED: Legacy configuration freshness check
如果提供了 config_id则加载该配置
否则从 dbrun.json 读取并加载最新配置。
⚠️ This function is deprecated and will be removed in a future version.
Use MemoryConfig objects with dependency injection instead.
For new code, use:
- app.services.memory_agent_service.MemoryAgentService.load_memory_config()
- app.services.memory_storage_service.MemoryStorageService.load_memory_config()
Args:
config_id: 可选的配置ID整数或字符串会自动转换
config_id: Configuration ID (deprecated)
Returns:
bool: 是否成功加载配置
bool: Always returns False (deprecated functionality)
"""
import logging
import warnings
logger = logging.getLogger(__name__)
with _config_lock:
try:
if config_id:
# 使用指定的 config_id
logger.debug(f"[definitions] 加载指定配置config_id={config_id}")
return reload_configuration_from_database(config_id)
else:
# 从数据库重新加载配置
logger.debug("[definitions] 从数据库重新加载最新配置")
memory_config = _load_from_database()
if not memory_config or not memory_config.get('selections'):
logger.warning("[definitions] 未能从数据库加载配置,使用当前配置")
return False
_expose_memory_constants(memory_config)
return True
except Exception as e:
logger.error(f"[definitions] 加载配置失败: {e}", exc_info=True)
return False
warnings.warn(
"ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
DeprecationWarning,
stacklevel=2
)
logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
"Use MemoryConfig objects with dependency injection instead.")
return False

View File

@@ -1,611 +0,0 @@
"""
运行时配置覆写工具 - 统一配置加载器
本模块作为统一的配置加载器,负责从多个来源加载配置并按优先级覆写。
配置来源优先级(从高到低):
1. 数据库配置PostgreSQL data_config 表)
2. 环境变量配置(.env 文件)
3. 默认配置runtime.json 文件)
支持的配置加载方式:
- 基于 config_id 的配置加载(从 dbrun.json 读取或前端传入)
- 基于 group_id 的配置加载(从 dbrun.json 读取)
- 环境变量覆写(支持 INTERNAL/EXTERNAL 网络模式)
主要功能:
- 从 PostgreSQL 数据库读取配置
- 从环境变量读取配置
- 从 runtime.json 读取默认配置
- 按优先级覆写配置项(仅在内存中,不修改文件)
- 支持多种配置字段selections、statement_extraction、deduplication、forgetting_engine、pruning、reflexion
使用场景:
- 应用启动时自动加载配置
- 前端切换配置时动态重新加载
- 多租户场景下的配置隔离
- 内外网环境自动切换
"""
import os
import json
import socket
from typing import Optional, Dict, Any, Literal
NetworkMode = Literal['internal', 'external']
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster):
    """Copy ``src[src_key]`` into ``target[target_key]`` after casting it.

    This is a best-effort helper: it never raises. Nothing is written when
    the source key is absent, when its value is ``None``, when ``src`` does
    not behave like a mapping, or when ``caster`` rejects the value.

    Args:
        target: Dictionary that receives the converted value.
        target_key: Key to write in ``target``.
        src: Dictionary the raw value is read from.
        src_key: Key to read in ``src``.
        caster: Callable applied to the raw value before it is stored.
    """
    import logging

    try:
        # A single lookup covers both "key missing" and "value is None".
        raw_value = src.get(src_key)
    except Exception:
        # ``src`` is not mapping-like (or the key is unusable); keep the
        # helper best-effort and leave ``target`` untouched.
        return

    if raw_value is None:
        return

    try:
        target[target_key] = caster(raw_value)
    except Exception as exc:
        # Deliberately non-fatal, but leave a trace instead of hiding the
        # failed conversion completely.
        logging.getLogger(__name__).debug(
            "Skipping config key %r: could not convert %r (%s)",
            src_key,
            raw_value,
            exc,
        )
def _to_bool(val: Any) -> bool:
    """Convert a loosely typed value to a boolean.

    Supported inputs:
    - bool: returned unchanged.
    - int / float: non-zero is True.
    - str: "true", "1", "on", "yes" are True; "false", "0", "off", "no"
      are False (case-insensitive, surrounding whitespace ignored). Blank
      or whitespace-only strings are False; any other non-blank string is
      True.
    - anything else: its ordinary truthiness, or False if evaluating the
      truthiness raises.

    Args:
        val: Value to convert.

    Returns:
        bool: The converted value.
    """
    if isinstance(val, bool):
        return val
    if isinstance(val, (int, float)):
        return bool(val)
    if isinstance(val, str):
        token = val.strip().lower()
        if token in {"true", "1", "on", "yes"}:
            return True
        if token in {"false", "0", "off", "no"}:
            return False
        # Judge the *stripped* text so that "   " is not treated as True.
        return bool(token)
    try:
        return bool(val)
    except Exception:
        # Objects whose __bool__/__len__ raise are treated as False.
        return False
def _make_pgsql_conn() -> Optional[object]:
    """Open a PostgreSQL connection configured from environment variables.

    Environment variables read:
    - DB_HOST: database host (defaults to "localhost")
    - DB_PORT: database port (defaults to 5432 when unset or empty)
    - DB_USER: database user
    - DB_PASSWORD: database password
    - DB_NAME: database name

    The returned connection has ``autocommit`` enabled.

    Returns:
        Optional[object]: The psycopg2 connection, or None when psycopg2 is
        not installed, DB_PORT is not an integer, or the connection attempt
        fails.
    """
    import logging

    log = logging.getLogger(__name__)

    host = os.getenv("DB_HOST", "localhost")
    user = os.getenv("DB_USER")
    password = os.getenv("DB_PASSWORD")
    dbname = os.getenv("DB_NAME")
    port_str = os.getenv("DB_PORT")

    try:
        import psycopg2  # type: ignore
    except ImportError:
        log.debug("psycopg2 is not installed; database config is unavailable")
        return None

    try:
        port = int(port_str) if port_str else 5432
    except ValueError:
        log.debug("DB_PORT=%r is not a valid integer port", port_str)
        return None

    try:
        conn = psycopg2.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            dbname=dbname,
        )
    except Exception as exc:
        # Best-effort: callers treat None as "no database available".
        log.debug("PostgreSQL connection to %s:%s failed: %s", host, port, exc)
        return None

    try:
        conn.autocommit = True
    except Exception as exc:
        # Do not leak the freshly opened connection if it cannot be configured.
        log.debug("Could not enable autocommit: %s", exc)
        try:
            conn.close()
        except Exception:
            pass
        return None
    return conn
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the most recently updated ``data_config`` row for a group.

    Args:
        group_id: Group identifier to look up.

    Returns:
        Optional[Dict[str, Any]]: The matching row (as returned by a
        RealDictCursor), or None when no connection could be opened, no row
        matches, or the query fails.
    """
    conn = _make_pgsql_conn()
    if conn is None:
        return None

    # Initialised up front so the cleanup below never touches an unbound
    # name when cursor creation itself fails.
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # The time zone is a nicety; continue with the server default.
            pass
        sql = (
            "SELECT group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation "
            "FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
        )
        cur.execute(sql, (group_id,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        # Best-effort lookup: any database error means "no config found".
        return None
    finally:
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
    """Fetch the ``data_config`` row with the given ``config_id``.

    Args:
        config_id: Configuration identifier. The query binds it as an
            integer, so strings are converted with ``int()``.

    Returns:
        The matching row as a mapping, or None when ``config_id`` is not
        convertible to int, no connection could be opened, the query fails,
        or no row matches.
    """
    # Validate the identifier first so a bad value never opens a connection.
    try:
        config_id_int = int(config_id)
    except (ValueError, TypeError):
        return None

    conn = _make_pgsql_conn()
    if conn is None:
        return None
    cur = None  # bound up front so ``finally`` never touches an unbound name
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # Setting the session time zone is optional; carry on without it.
            pass
        sql = (
            "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation, "
            " deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
            " statement_granularity, include_dialogue_context, max_context, "
            " \"offset\" AS offset, lambda_time, lambda_mem, "
            " pruning_enabled, pruning_scene, pruning_threshold, "
            " llm_id, embedding_id "
            "FROM data_config WHERE config_id = %s LIMIT 1"
        )
        cur.execute(sql, (config_id_int,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        # Best-effort lookup: any database error means "no override".
        return None
    finally:
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
"""从 dbrun.json 读取 group_id
Args:
project_root: 项目根目录路径
Returns:
Optional[str]: group_id未找到时返回 None
"""
try:
path = os.path.join(project_root, "dbrun.json")
if not os.path.isfile(path):
return None
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
if "group_id" in data:
return str(data.get("group_id"))
sel = data.get("selections", {})
if isinstance(sel, dict) and "group_id" in sel:
return str(sel.get("group_id"))
return None
except Exception:
return None
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
"""从 dbrun.json 读取 config_id
Args:
project_root: 项目根目录路径
Returns:
Optional[str]: config_id未找到时返回 None
"""
try:
path = os.path.join(project_root, "dbrun.json")
if not os.path.isfile(path):
return None
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
if "config_id" in data:
return str(data.get("config_id"))
sel = data.get("selections", {})
if isinstance(sel, dict) and "config_id" in sel:
return str(sel.get("config_id"))
return None
except Exception:
return None
def _apply_overrides_from_db_row(
    runtime_cfg: Dict[str, Any],
    db_row: Optional[Dict[str, Any]],
    identifier: str,
    identifier_type: str = "config_id"
) -> Dict[str, Any]:
    """Overlay values from a ``data_config`` row onto the runtime config.

    The selected identifier is always recorded under
    ``runtime_cfg["selections"][identifier_type]``. When ``db_row`` is truthy,
    the "selections", "statement_extraction", "deduplication",
    "forgetting_engine" and "pruning" sections are updated from it.
    ``runtime_cfg`` is mutated in place.

    Args:
        runtime_cfg: Runtime configuration dictionary to update.
        db_row: Database row, or None/empty when nothing was found.
        identifier: Value of the group_id or config_id that was selected.
        identifier_type: Key to store ``identifier`` under ("group_id" or
            "config_id").

    Returns:
        The same ``runtime_cfg`` object. If anything raises part-way through,
        the partially updated config is returned.
    """
    try:
        selections = runtime_cfg.setdefault("selections", {})
        selections[identifier_type] = identifier
        if not db_row:
            return runtime_cfg

        # selections: copied through the str converter.
        # NOTE(review): _set_if_present is defined outside this view; the
        # call shape is (target, target_key, source, source_key, converter).
        for key in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
                    "t_type_strict", "t_name_strict", "t_overall",
                    "statement_granularity", "include_dialogue_context"):
            _set_if_present(selections, key, db_row, key, str)

        # Model ids may arrive as UUID objects; store them as strings.
        for uuid_field in ("llm_id", "embedding_id"):
            value = db_row.get(uuid_field)
            if value is not None:
                selections[uuid_field] = str(value)

        # statement_extraction
        stmt = runtime_cfg.setdefault("statement_extraction", {})
        _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
        _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
        _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)

        # deduplication
        dedup = runtime_cfg.setdefault("deduplication", {})
        for key in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation"):
            _set_if_present(dedup, key, db_row, key, _to_bool)
        _set_if_present(dedup, "deep_retrieval", db_row, "deep_retrieval", _to_bool)

        # forgetting_engine
        forgetting = runtime_cfg.setdefault("forgetting_engine", {})
        _set_if_present(forgetting, "offset", db_row, "offset", float)
        _set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float)
        _set_if_present(forgetting, "lambda_mem", db_row, "lambda_mem", float)

        # pruning
        pruning = runtime_cfg.setdefault("pruning", {})
        _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
        _set_if_present(pruning, "scene", db_row, "pruning_scene", str)
        # The threshold is stored as a float clamped to [0.0, 0.9]; values
        # that cannot be converted are ignored.
        raw_threshold = db_row.get("pruning_threshold")
        if raw_threshold is not None:
            try:
                pruning["threshold"] = max(0.0, min(0.9, float(raw_threshold)))
            except (TypeError, ValueError):
                pass
        return runtime_cfg
    except Exception:
        # Best-effort: never let an override failure break config loading.
        return runtime_cfg
def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the dbrun.json ``group_id``.

    Steps:
        1. Read ``group_id`` from dbrun.json.
        2. Look up the matching database row.
        3. Apply the row to ``runtime_cfg`` (in memory only).

    Args:
        project_root: Project root directory containing dbrun.json.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        The runtime configuration after overrides; returned as-is when no
        group_id is selected or an error occurs.
    """
    try:
        gid = _load_dbrun_group_id(project_root)
        if gid:
            row = _fetch_db_config_by_group_id(gid)
            if row:
                return _apply_overrides_from_db_row(runtime_cfg, row, gid, "group_id")
            # No database row: still record which group was selected.
            selections = runtime_cfg.setdefault("selections", {})
            selections["group_id"] = gid
        return runtime_cfg
    except Exception:
        return runtime_cfg
def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the dbrun.json ``config_id``.

    Steps:
        1. Read ``config_id`` from dbrun.json.
        2. Look up the matching database row.
        3. Apply the row to ``runtime_cfg`` (in memory only).

    Args:
        project_root: Project root directory containing dbrun.json.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        The runtime configuration after overrides; returned as-is when no
        config_id is selected or an error occurs.
    """
    try:
        cid = _load_dbrun_config_id(project_root)
        if cid:
            row = _fetch_db_config_by_config_id(cid)
            return _apply_overrides_from_db_row(runtime_cfg, row, cid, "config_id")
        return runtime_cfg
    except Exception:
        return runtime_cfg
def apply_runtime_overrides_with_config_id(
    project_root: str,
    runtime_cfg: Dict[str, Any],
    config_id: str
) -> tuple[Dict[str, Any], bool]:
    """Overlay database settings for an explicit ``config_id``.

    Unlike the dbrun.json-based variants, the identifier is supplied by the
    caller (for example, when a front end switches configurations).

    Args:
        project_root: Project root directory (not used by this function).
        runtime_cfg: Runtime configuration dictionary.
        config_id: Configuration identifier.

    Returns:
        ``(runtime_cfg, loaded)`` where ``loaded`` is True only when a
        database row was found and applied.
    """
    try:
        cid = str(config_id).strip()
        row = _fetch_db_config_by_config_id(cid) if cid else None
        if row is None:
            return runtime_cfg, False
        return _apply_overrides_from_db_row(runtime_cfg, row, cid, "config_id"), True
    except Exception:
        return runtime_cfg, False
# ============================================================================
# 以下函数已注释:不再需要网络模式自动检测功能
# ============================================================================
# def get_server_ip() -> str:
# """
# 获取当前服务器的IP地址
#
# Returns:
# 服务器IP地址字符串
# """
# try:
# # 方式1从环境变量获取优先
# server_ip = os.getenv('SERVER_IP')
# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
# return server_ip
#
# # 方式2通过socket获取
# hostname = socket.gethostname()
# ip_address = socket.gethostbyname(hostname)
#
# # 如果是本地回环地址尝试获取真实IP
# if ip_address.startswith('127.'):
# # 尝试连接外部地址来获取本机IP
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# try:
# s.connect(('8.8.8.8', 80))
# ip_address = s.getsockname()[0]
# finally:
# s.close()
#
# return ip_address
# except Exception as e:
# print(f"[overrides] 获取服务器IP失败: {e},使用默认值 127.0.0.1")
# return '127.0.0.1'
# def auto_detect_network_mode() -> NetworkMode:
# """
# 自动检测网络模式基于服务器IP
#
# 规则:
# - 如果服务器IP在内网IP列表中 → internal内网
# - 其他IP → external外网
#
# 可以通过环境变量 INTERNAL_SERVER_IPS 自定义内网IP列表逗号分隔
#
# Returns:
# 'internal' 或 'external'
# """
# server_ip = get_server_ip()
#
# # 从环境变量获取内网IP列表支持多个IP逗号分隔
# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
#
# # 判断当前IP是否在内网IP列表中
# if server_ip in internal_ips:
# print(f"[overrides] 自动检测服务器IP {server_ip} 属于内网,使用 INTERNAL 配置")
# return 'internal'
# else:
# print(f"[overrides] 自动检测服务器IP {server_ip} 属于外网,使用 EXTERNAL 配置")
# return 'external'
# ============================================================================
# 环境变量覆写功能已废弃 - 不再使用
# ============================================================================
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
# """
# 从环境变量覆写配置(已废弃)
# """
# return runtime_cfg
def load_unified_config(
    project_root: str,
    config_id: Optional[int | str] = None,
    group_id: Optional[str] = None,
    network_mode: NetworkMode = None,
    env_override_models: bool = True
) -> Dict[str, Any]:
    """Load the runtime configuration and overlay database settings.

    runtime.json provides the base configuration. One database lookup is
    then attempted, using the first identifier available in this order:
        1. the ``config_id`` argument
        2. the ``group_id`` argument
        3. ``config_id`` from dbrun.json
        4. ``group_id`` from dbrun.json
    If the chosen lookup finds no row, the base configuration is returned
    unchanged; no further identifiers are tried.

    Args:
        project_root: Project root directory containing runtime.json and
            dbrun.json.
        config_id: Optional configuration id (int or str).
        group_id: Optional group id.
        network_mode: Deprecated; kept only for backward compatibility.
        env_override_models: Deprecated; kept only for backward compatibility.

    Returns:
        The final runtime configuration. ``{"selections": {}}`` is returned
        if loading fails unexpectedly.
    """
    try:
        # Step 1: runtime.json is the base configuration.
        runtime_config_path = os.path.join(project_root, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                runtime_cfg = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            runtime_cfg = {"selections": {}}
        # A JSON root that is not an object cannot hold config sections.
        if not isinstance(runtime_cfg, dict):
            runtime_cfg = {"selections": {}}

        # Step 2: overlay database settings from the first usable identifier.
        if config_id:
            db_row = _fetch_db_config_by_config_id(config_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, config_id, "config_id")
        elif group_id:
            db_row = _fetch_db_config_by_group_id(group_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, group_id, "group_id")
        else:
            # Fall back to the identifiers stored in dbrun.json.
            dbrun_config_id = _load_dbrun_config_id(project_root)
            if dbrun_config_id:
                db_row = _fetch_db_config_by_config_id(dbrun_config_id)
                if db_row:
                    runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_config_id, "config_id")
            else:
                dbrun_group_id = _load_dbrun_group_id(project_root)
                if dbrun_group_id:
                    db_row = _fetch_db_config_by_group_id(dbrun_group_id)
                    if db_row:
                        runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_group_id, "group_id")
        return runtime_cfg
    except Exception:
        # Best-effort: callers always receive a usable dictionary.
        return {"selections": {}}
# Backward-compatible alias: the old public name maps to the config_id variant.
apply_runtime_overrides = apply_runtime_overrides_by_config

View File

@@ -0,0 +1,11 @@
"""Embedder utilities module."""
from app.core.memory.utils.embedder.embedder_utils import (
get_embedder_client,
get_embedder_client_from_config,
)
__all__ = [
"get_embedder_client",
"get_embedder_client_from_config",
]

View File

@@ -0,0 +1,81 @@
"""Embedder Client Utilities
This module provides centralized functions for creating embedder clients.
"""
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.models.base import RedBearModelConfig
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
def get_embedder_client_from_config(memory_config: "MemoryConfig") -> OpenAIEmbedderClient:
    """Build an embedder client from a MemoryConfig object.

    Preferred entry point when a MemoryConfig is available; it delegates to
    ``get_embedder_client`` with the config's ``embedding_model_id``.

    Args:
        memory_config: Configuration carrying ``embedding_model_id`` and
            ``config_id``.

    Returns:
        OpenAIEmbedderClient: Initialized embedder client.

    Raises:
        ValueError: If no embedding model id is configured, or if the client
            cannot be initialized.

    Example:
        >>> embedder_client = get_embedder_client_from_config(memory_config)
    """
    if memory_config.embedding_model_id:
        return get_embedder_client(str(memory_config.embedding_model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no embedding model configured"
    )
def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
    """Build an embedder client directly from an embedding model id.

    Intended for tests, evaluations and legacy callers that only hold a model
    id; code that has a MemoryConfig should call
    ``get_embedder_client_from_config`` instead.

    Args:
        embedding_id: Embedding model id (required).

    Returns:
        OpenAIEmbedderClient: Initialized embedder client.

    Raises:
        ValueError: If ``embedding_id`` is empty, its configuration cannot be
            loaded, or the client cannot be initialized.

    Example:
        >>> embedder_client = get_embedder_client("model-uuid-string")
    """
    if not embedding_id:
        raise ValueError("Embedding ID is required but was not provided")

    # Stage 1: resolve the stored configuration for this model id.
    try:
        embedder_config_dict = get_embedder_config(embedding_id)
    except Exception as e:
        raise ValueError(f"Invalid embedding ID '{embedding_id}': {e!s}") from e

    # Stage 2: build the client from that configuration.
    try:
        return OpenAIEmbedderClient(RedBearModelConfig(**embedder_config_dict))
    except Exception as e:
        model_name = embedder_config_dict.get('model_name', 'unknown')
        raise ValueError(
            f"Failed to initialize embedder client for model '{model_name}': {e!s}"
        ) from e

View File

@@ -1,67 +1,107 @@
import os
from pydantic import BaseModel
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from pydantic import BaseModel
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
async def handle_response(response: BaseModel) -> dict:
    """Serialize a pydantic model instance to a plain dict.

    Args:
        response: Model instance to serialize. The parameter is annotated as
            an instance because ``model_dump()`` is an instance method.

    Returns:
        The result of ``response.model_dump()``.
    """
    return response.model_dump()
def get_llm_client(llm_id: str | None = None):
llm_id = llm_id or config_defs.SELECTED_LLM_ID
def get_llm_client_from_config(memory_config: "MemoryConfig") -> OpenAIClient:
    """Build an LLM client from a MemoryConfig object.

    Preferred entry point when a MemoryConfig is available; it delegates to
    ``get_llm_client`` with the config's ``llm_model_id``.

    Args:
        memory_config: Configuration carrying ``llm_model_id`` and
            ``config_id``.

    Returns:
        OpenAIClient: Initialized LLM client.

    Raises:
        ValueError: If no LLM model id is configured, or if the client cannot
            be initialized.

    Example:
        >>> llm_client = get_llm_client_from_config(memory_config)
    """
    if memory_config.llm_model_id:
        return get_llm_client(str(memory_config.llm_model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no LLM model configured"
    )
# Validate LLM ID exists before attempting to get config
def get_llm_client(llm_id: str):
    """Build an LLM client directly from a model id.

    Intended for tests, evaluations and legacy callers that only hold a model
    id; code that has a MemoryConfig should call
    ``get_llm_client_from_config`` instead.

    Args:
        llm_id: LLM model id (required).

    Returns:
        OpenAIClient: Initialized LLM client.

    Raises:
        ValueError: If ``llm_id`` is empty, its configuration cannot be
            loaded, or the client cannot be initialized.

    Example:
        >>> llm_client = get_llm_client("model-uuid-string")
    """
    if not llm_id:
        raise ValueError("LLM ID is required but was not provided")

    # Stage 1: resolve the stored configuration for this model id.
    try:
        model_config = get_model_config(llm_id)
    except Exception as e:
        raise ValueError(f"Invalid LLM ID '{llm_id}': {e!s}") from e

    # Stage 2: build the client from that configuration.
    try:
        redbear_config = RedBearModelConfig(
            model_name=model_config.get("model_name"),
            provider=model_config.get("provider"),
            api_key=model_config.get("api_key"),
            base_url=model_config.get("base_url"),
        )
        return OpenAIClient(redbear_config, type_=model_config.get("type"))
    except Exception as e:
        model_name = model_config.get('model_name', 'unknown')
        raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {e!s}") from e
def get_reranker_client(rerank_id: str | None = None):
def get_reranker_client(rerank_id: str):
"""
Get an LLM client configured for reranking.
Args:
rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID.
rerank_id: Reranker model ID (required)
Returns:
OpenAIClient: Initialized client for the reranker model
Raises:
ValueError: If rerank_id is invalid or client initialization fails
ValueError: If rerank_id is not provided or client initialization fails
"""
rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID
# Validate rerank ID exists before attempting to get config
if not rerank_id:
raise ValueError("Rerank ID is required but was not provided")
try:
model_config = get_model_config(rerank_id)
except Exception as e:
# Re-raise with clear error message about invalid rerank ID
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
try:

View File

@@ -10,28 +10,29 @@
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
"""
import os
import asyncio
import json
import logging
import asyncio
from typing import List, Dict, Any
import os
import uuid
from typing import Any, Dict, List
#TODO: Fix this
# Default values (previously from definitions.py)
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
from app.core.memory.utils.config.definitions import (
REFLEXION_ENABLED,
REFLEXION_ITERATION_PERIOD,
REFLEXION_RANGE,
REFLEXION_BASELINE,
)
from app.db import get_db
from sqlalchemy.orm import Session
from app.models.retrieval_info import RetrievalInfo
from app.core.memory.utils.config.get_data import get_data
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from sqlalchemy.orm import Session
# 并发限制(可通过环境变量覆盖)
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))

View File

@@ -1,6 +1,21 @@
"""
Validators for file upload system.
Validators package for various validation utilities.
"""
from app.core.validators.file_validator import FileValidator, ValidationResult
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_llm_model,
validate_model_exists_and_active,
)
__all__ = ["FileValidator", "ValidationResult"]
__all__ = [
# File validators
"FileValidator",
"ValidationResult",
# Memory config validators
"validate_model_exists_and_active",
"validate_and_resolve_model_id",
"validate_embedding_model",
"validate_llm_model",
]

View File

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Validators
This module provides validation functions for memory configuration models.
Functions:
validate_model_exists_and_active: Validate model exists and is active
validate_and_resolve_model_id: Validate and resolve model ID with DB lookup
validate_embedding_model: Validate embedding model availability
validate_llm_model: Validate LLM model availability
"""
import time
from typing import Optional, Union
from uuid import UUID
from app.core.logging_config import get_config_logger
from app.schemas.memory_config_schema import (
InvalidConfigError,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID."""
if model_id is None:
return None
if isinstance(model_id, UUID):
return model_id
if isinstance(model_id, str):
if not model_id.strip():
return None
try:
return UUID(model_id.strip())
except ValueError:
raise InvalidConfigError(
f"Invalid UUID format for {model_type} model ID: '{model_id}'",
field_name=f"{model_type}_model_id",
invalid_value=model_id,
config_id=config_id,
workspace_id=workspace_id
)
raise InvalidConfigError(
f"Invalid type for {model_type} model ID: expected str or UUID, got {type(model_id).__name__}",
field_name=f"{model_type}_model_id",
invalid_value=model_id,
config_id=config_id,
workspace_id=workspace_id
)
def validate_model_exists_and_active(
    model_id: UUID,
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[str, bool]:
    """Validate that a model exists and is active.

    Args:
        model_id: Model UUID to validate.
        model_type: Type of model ("llm", "embedding", "rerank").
        db: Database session.
        tenant_id: Optional tenant id passed through to the repository lookup.
        config_id: Optional configuration id for error context.
        workspace_id: Optional workspace id for error context.

    Returns:
        Tuple of (model_name, is_active).

    Raises:
        ModelNotFoundError: If the model does not exist.
        ModelInactiveError: If the model exists but is inactive.
        Exception: Any other lookup failure is logged and re-raised.
    """
    # Imported here rather than at module level (presumably to avoid an
    # import cycle - TODO confirm).
    from app.repositories.model_repository import ModelConfigRepository

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()

    try:
        model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
        elapsed_ms = (time.perf_counter() - started) * 1000

        if not model:
            logger.warning(
                "Model not found",
                extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
            )
            raise ModelNotFoundError(
                model_id=model_id,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} not found"
            )

        if not model.is_active:
            logger.warning(
                "Model inactive",
                extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
            )
            raise ModelInactiveError(
                model_id=model_id,
                model_name=model.name,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} ({model.name}) is inactive"
            )

        logger.debug(
            "Model validation successful",
            extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
        )
        return model.name, model.is_active

    except (ModelNotFoundError, ModelInactiveError):
        # Expected validation outcomes: already logged above.
        raise
    except Exception as e:
        logger.error(f"Model validation failed: {e}", exc_info=True)
        raise
def validate_and_resolve_model_id(
    model_id_str: Union[str, UUID, None],
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    required: bool = False,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]:
    """Parse a model id and confirm the model exists and is active.

    Args:
        model_id_str: Raw model id (string, UUID or None).
        model_type: Model kind, used in messages and the error field name.
        db: Database session.
        tenant_id: Optional tenant id passed to the existence check.
        required: When True, a missing or blank id raises instead of
            returning ``(None, None)``.
        config_id: Optional configuration id for error context.
        workspace_id: Optional workspace id for error context.

    Returns:
        Tuple of (validated_uuid, model_name), or (None, None) if the id is
        empty and not required.

    Raises:
        InvalidConfigError: If the id is required but missing, or malformed.
    """
    is_blank = model_id_str is None or (
        isinstance(model_id_str, str) and not model_id_str.strip()
    )
    # Blank input skips parsing; both "nothing to validate" cases are handled below.
    model_uuid = None if is_blank else _parse_model_id(
        model_id_str, model_type, config_id, workspace_id
    )

    if model_uuid is None:
        if not required:
            return None, None
        raise InvalidConfigError(
            f"{model_type.title()} model ID is required",
            field_name=f"{model_type}_model_id",
            invalid_value=model_id_str,
            config_id=config_id,
            workspace_id=workspace_id
        )

    model_name, _ = validate_model_exists_and_active(
        model_uuid, model_type, db, tenant_id, config_id, workspace_id
    )
    return model_uuid, model_name
def validate_embedding_model(
    config_id: int,
    embedding_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Validate the embedding model of a configuration and return its UUID.

    Args:
        config_id: Configuration id, used in error messages and context.
        embedding_id: Raw embedding model id (string, UUID or None).
        db: Database session.
        tenant_id: Optional tenant id passed to the existence check.
        workspace_id: Optional workspace id for error context.

    Returns:
        The validated embedding model UUID.

    Raises:
        InvalidConfigError: If embedding_id is not provided or invalid.
        ModelNotFoundError: If the embedding model does not exist.
        ModelInactiveError: If the embedding model is inactive.
    """
    def _not_configured():
        # Same error for "nothing supplied" and "resolved to nothing".
        return InvalidConfigError(
            f"Configuration {config_id} has no embedding model configured",
            field_name="embedding_model_id",
            invalid_value=embedding_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    if embedding_id is None or (isinstance(embedding_id, str) and not embedding_id.strip()):
        raise _not_configured()

    resolved_uuid, _ = validate_and_resolve_model_id(
        embedding_id, "embedding", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )
    if resolved_uuid is None:
        raise _not_configured()
    return resolved_uuid
def validate_llm_model(
    config_id: int,
    llm_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Validate the LLM model of a configuration and return its UUID.

    Args:
        config_id: Configuration id, used in error messages and context.
        llm_id: Raw LLM model id (string, UUID or None).
        db: Database session.
        tenant_id: Optional tenant id passed to the existence check.
        workspace_id: Optional workspace id for error context.

    Returns:
        The validated LLM model UUID.

    Raises:
        InvalidConfigError: If llm_id is not provided or invalid.
        ModelNotFoundError: If the LLM model does not exist.
        ModelInactiveError: If the LLM model is inactive.
    """
    def _not_configured():
        # Same error for "nothing supplied" and "resolved to nothing".
        return InvalidConfigError(
            f"Configuration {config_id} has no LLM model configured",
            field_name="llm_model_id",
            invalid_value=llm_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    if llm_id is None or (isinstance(llm_id, str) and not llm_id.strip()):
        raise _not_configured()

    resolved_uuid, _ = validate_and_resolve_model_id(
        llm_id, "llm", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )
    if resolved_uuid is None:
        raise _not_configured()
    return resolved_uuid

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Model - Backward Compatibility
This module provides backward compatibility for imports.
All classes have been moved to app.schemas.memory_config_schema.
DEPRECATED: Import from app.schemas.memory_config_schema instead.
"""
# Re-export for backward compatibility
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
MemoryConfigValidation,
ModelInactiveError,
ModelNotFoundError,
ModelValidation,
WorkspaceNotFoundError,
WorkspaceValidation,
validate_memory_config_data,
validate_model_data,
validate_workspace_data,
)
__all__ = [
"ConfigurationError",
"InvalidConfigError",
"MemoryConfig",
"MemoryConfigValidation",
"ModelInactiveError",
"ModelNotFoundError",
"ModelValidation",
"WorkspaceNotFoundError",
"WorkspaceValidation",
"validate_memory_config_data",
"validate_model_data",
"validate_workspace_data",
]

View File

@@ -8,24 +8,26 @@ Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作
"""
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import desc
import uuid
from typing import Dict, List, Optional, Tuple
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.core.logging_config import get_db_logger
from sqlalchemy import desc
from sqlalchemy.orm import Session
# 获取数据库专用日志器
db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
class DataConfigRepository:
@@ -443,7 +445,129 @@ class DataConfigRepository:
except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
    """Get a data config together with its associated workspace.

    Args:
        db: Database session.
        config_id: Configuration id.

    Returns:
        Optional[tuple]: ``(DataConfig, Workspace)``, or None if no
        configuration with this id exists.

    Raises:
        ValueError: If the configuration exists but has no workspace id, or
            references a workspace that does not exist.
        Exception: Any other query failure is logged and re-raised.
    """
    import time

    from app.models.workspace_model import Workspace

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    config_logger.info(
        "Loading configuration with workspace",
        extra={
            "operation": "get_config_with_workspace",
            "config_id": config_id
        }
    )
    db_logger.debug(f"Querying data config and workspace: config_id={config_id}")

    try:
        # Inner join: a result is returned only when both rows exist.
        result = db.query(DataConfig, Workspace).join(
            Workspace, DataConfig.workspace_id == Workspace.id
        ).filter(DataConfig.config_id == config_id).first()

        elapsed_ms = (time.perf_counter() - start_time) * 1000

        if not result:
            # Tell "config missing" apart from "config present, workspace missing".
            config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
            if config_only:
                if config_only.workspace_id is None:
                    config_logger.error(
                        "Configuration has no associated workspace ID",
                        extra={
                            "operation": "get_config_with_workspace",
                            "config_id": config_id,
                            "workspace_id": None,
                            "load_result": "no_workspace_id",
                            "elapsed_ms": elapsed_ms
                        }
                    )
                    db_logger.error(f"Data config {config_id} has no associated workspace ID")
                    raise ValueError(f"Configuration {config_id} has no associated workspace")
                else:
                    config_logger.error(
                        "Configuration references non-existent workspace",
                        extra={
                            "operation": "get_config_with_workspace",
                            "config_id": config_id,
                            "workspace_id": str(config_only.workspace_id),
                            "load_result": "workspace_not_found",
                            "elapsed_ms": elapsed_ms
                        }
                    )
                    db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
                    raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")

            config_logger.debug(
                "Configuration not found",
                extra={
                    "operation": "get_config_with_workspace",
                    "config_id": config_id,
                    "load_result": "not_found",
                    "elapsed_ms": elapsed_ms
                }
            )
            db_logger.debug(f"Data config not found: config_id={config_id}")
            return None

        config, workspace = result

        config_logger.info(
            "Configuration with workspace loaded successfully",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "config_name": config.config_name,
                "workspace_id": str(workspace.id),
                "workspace_name": workspace.name,
                "tenant_id": str(workspace.tenant_id),
                "load_result": "success",
                "elapsed_ms": elapsed_ms
            }
        )
        db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
        return (config, workspace)

    except ValueError:
        # Known business errors raised above: already logged, re-raise as-is.
        raise
    except Exception as e:
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        config_logger.error(
            "Failed to load configuration with workspace",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "load_result": "error",
                "error_type": type(e).__name__,
                "error_message": str(e),
                "elapsed_ms": elapsed_ms
            },
            exc_info=True
        )
        db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
        raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
"""获取所有配置参数

View File

@@ -8,13 +8,13 @@ class UserInput(BaseModel):
history: list[dict]
search_switch: str
group_id: str
config_id: Optional[str] = None
config_id: str
class Write_UserInput(BaseModel):
message: str
group_id: str
config_id: Optional[str] = None
config_id: str
class End_User_Information(BaseModel):
end_user_name: str # 这是要更新的用户名

View File

@@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Schemas
This module provides schema definitions for memory configuration.
Classes:
MemoryConfig: Immutable memory configuration loaded from database
MemoryConfigValidation: Pydantic model for configuration validation
WorkspaceValidation: Pydantic model for workspace validation
ModelValidation: Pydantic model for model configuration validation
ConfigurationError: Base exception for configuration-related errors
WorkspaceNotFoundError: Raised when workspace does not exist
ModelNotFoundError: Raised when a required model does not exist
ModelInactiveError: Raised when a required model exists but is inactive
InvalidConfigError: Raised when configuration validation fails
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Literal, Optional, Union
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
# ==================== Configuration Exception Classes ====================
class ConfigurationError(Exception):
    """Root of the configuration exception hierarchy.

    The message handed to ``Exception`` is enriched with whatever context was
    supplied (configuration ID, workspace ID, free-form key/value pairs) so
    that a log line alone is enough to locate the failing configuration.

    Attributes:
        config_id: Configuration ID the error relates to, or ``None``.
        workspace_id: Workspace the error relates to, or ``None``.
        context: Extra key/value details; always a dict (possibly empty).
    """

    def __init__(
        self,
        message: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        context: Optional[Dict[str, Any]] = None,
    ):
        """Store the context and build the enriched message.

        Args:
            message: Human-readable description of the failure.
            config_id: Optional configuration ID, rendered as a prefix.
            workspace_id: Optional workspace ID, rendered as a suffix.
            context: Optional mapping rendered as ``k=v`` pairs at the end.
        """
        self.config_id = config_id
        self.workspace_id = workspace_id
        self.context = context or {}

        # Assemble the message from space-separated segments:
        # "[Configuration N: ]<message>[ (workspace: W)][ [Context: k=v, ...]]"
        segments = [
            message if config_id is None else f"Configuration {config_id}: {message}"
        ]
        if workspace_id is not None:
            segments.append(f"(workspace: {workspace_id})")
        if self.context:
            rendered = ", ".join(
                f"{key}={value}" for key, value in self.context.items()
            )
            segments.append(f"[Context: {rendered}]")

        super().__init__(" ".join(segments))
class WorkspaceNotFoundError(ConfigurationError):
    """Signals that the referenced workspace is missing.

    Args:
        workspace_id: ID of the workspace that could not be found.
        config_id: Optional configuration ID for context.
        message: Optional override for the default message.
    """

    def __init__(
        self,
        workspace_id: UUID,
        config_id: Optional[int] = None,
        message: Optional[str] = None,
    ):
        # Fall back to the standard wording only when the caller gave none.
        text = (
            f"Workspace {workspace_id} not found in database"
            if message is None
            else message
        )
        super().__init__(
            text,
            config_id=config_id,
            workspace_id=workspace_id,
            context={"workspace_id": str(workspace_id)},
        )
class ModelNotFoundError(ConfigurationError):
    """Signals that a model required by the configuration is missing.

    Args:
        model_id: ID of the model that was looked up.
        model_type: Kind of model, used in the message and the context.
        config_id: Optional configuration ID for context.
        workspace_id: Optional workspace ID for context.
        message: Optional override for the default message.
    """

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        # Default wording is only computed when no explicit message is given.
        text = message
        if text is None:
            text = f"{model_type.title()} model {model_id} not found in database"

        details = dict(
            model_id=str(model_id),
            model_type=model_type,
            failure_type="not_found",
        )
        super().__init__(
            text, config_id=config_id, workspace_id=workspace_id, context=details
        )
class ModelInactiveError(ConfigurationError):
    """Signals that a required model was found but is not active.

    Args:
        model_id: ID of the inactive model.
        model_name: Display name of the inactive model.
        model_type: Kind of model, used in the message and the context.
        config_id: Optional configuration ID for context.
        workspace_id: Optional workspace ID for context.
        message: Optional override for the default message.
    """

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_name: str,
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        # Default wording is only computed when no explicit message is given.
        text = message
        if text is None:
            text = f"{model_type.title()} model {model_id} ({model_name}) is inactive"

        details = dict(
            model_id=str(model_id),
            model_name=model_name,
            model_type=model_type,
            failure_type="inactive",
        )
        super().__init__(
            text, config_id=config_id, workspace_id=workspace_id, context=details
        )
class InvalidConfigError(ConfigurationError):
    """Signals that configuration data did not pass validation.

    Args:
        message: Description of the validation failure.
        field_name: Name of the offending field, if known.
        invalid_value: The rejected value, if known; recorded as its string
            form together with its type name.
        config_id: Optional configuration ID for context.
        workspace_id: Optional workspace ID for context.
    """

    def __init__(
        self,
        message: str,
        field_name: Optional[str] = None,
        invalid_value: Optional[Any] = None,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
    ):
        details = {}
        if field_name is not None:
            details["field_name"] = field_name
        if invalid_value is not None:
            # Keyword order fixes the key order seen in the rendered message.
            details.update(
                invalid_value=str(invalid_value),
                invalid_value_type=type(invalid_value).__name__,
            )
        super().__init__(
            message, config_id=config_id, workspace_id=workspace_id, context=details
        )
# ==================== Pydantic Validation Models ====================
class MemoryConfigValidation(BaseModel):
    """Pydantic schema that checks memory configuration data from the database.

    Unknown fields are rejected and assignments are re-validated (see
    ``model_config``). Name fields are stripped, ``storage_type`` is
    lower-cased, and the parameter dicts may not carry reserved keys.
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    # Identity of the configuration and its owning workspace/tenant.
    config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
    config_name: str = Field(..., min_length=1, max_length=255)
    workspace_id: UUID = Field(..., description="Workspace UUID")
    workspace_name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")

    # Models: embedding and LLM are mandatory, rerank is optional.
    embedding_model_id: UUID = Field(..., description="Embedding model UUID (required)")
    embedding_model_name: str = Field(..., min_length=1, max_length=255)
    llm_model_id: UUID = Field(..., description="LLM model UUID (required)")
    llm_model_name: str = Field(..., min_length=1, max_length=255)
    rerank_model_id: Optional[UUID] = Field(None, description="Rerank model UUID (optional)")
    rerank_model_name: Optional[str] = Field(None, max_length=255)

    # Storage and processing options.
    storage_type: str = Field(..., min_length=1, max_length=50)
    chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)
    reflexion_enabled: bool = Field(default=False)
    reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
    reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
    reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")

    # Free-form per-model parameters and the schema version tag.
    llm_params: Dict[str, Any] = Field(default_factory=dict)
    embedding_params: Dict[str, Any] = Field(default_factory=dict)
    config_version: str = Field(default="2.0", min_length=1, max_length=10)

    @field_validator("config_name", "workspace_name", "embedding_model_name", "llm_model_name")
    @classmethod
    def validate_non_empty_strings(cls, v):
        """Strip the value and reject it when nothing remains."""
        trimmed = v.strip() if v else ""
        if trimmed:
            return trimmed
        raise ValueError("Field cannot be empty or whitespace-only")

    @field_validator("storage_type")
    @classmethod
    def validate_storage_type(cls, v):
        """Lower-case the storage type and check it against the known list."""
        valid_types = ["neo4j", "elasticsearch", "qdrant", "milvus", "chroma"]
        lowered = v.lower()
        if lowered in valid_types:
            return lowered
        raise ValueError(f"Storage type must be one of: {valid_types}")

    @field_validator("llm_params", "embedding_params")
    @classmethod
    def validate_model_params(cls, v):
        """Require a dict that does not contain any reserved key."""
        if not isinstance(v, dict):
            raise ValueError("Model parameters must be a dictionary")
        reserved_keys = ["model_id", "model_name", "api_key", "base_url"]
        # Report the first reserved key in the dict's own iteration order.
        clash = next((name for name in v if name in reserved_keys), None)
        if clash is not None:
            raise ValueError(f"Model parameters cannot contain reserved parameter '{clash}'")
        return v
class WorkspaceValidation(BaseModel):
    """Pydantic schema that checks workspace data from the database.

    Unknown fields are rejected and assignments are re-validated. Model ID
    fields must be UUID strings (empty becomes ``None``), and an inactive
    workspace fails validation.
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Workspace UUID")
    name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")
    storage_type: Optional[str] = Field(None, max_length=50)
    # Model references are kept as strings but must parse as UUIDs.
    llm: Optional[str] = Field(None)
    embedding: Optional[str] = Field(None)
    rerank: Optional[str] = Field(None)
    is_active: bool = Field(default=True)

    @field_validator("llm", "embedding", "rerank")
    @classmethod
    def validate_model_ids(cls, v):
        """Return the stripped ID if it parses as a UUID; ``None`` if unset."""
        if not v:
            # Covers both None and the empty string.
            return None
        candidate = v.strip()
        try:
            UUID(candidate)
        except ValueError:
            raise ValueError("Model ID must be a valid UUID string")
        return candidate

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Only active workspaces may be used for configuration loading."""
        if v:
            return v
        raise ValueError("Workspace must be active for configuration loading")
class ModelValidation(BaseModel):
    """Pydantic schema that checks model configuration data.

    Unknown fields are rejected and assignments are re-validated. ``type`` is
    lower-cased and restricted to the known model kinds; an inactive model
    fails validation.
    """

    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Model UUID")
    name: str = Field(..., min_length=1, max_length=255)
    type: str = Field(..., description="Model type (llm, embedding, rerank)")
    tenant_id: UUID = Field(..., description="Tenant UUID")
    is_active: bool = Field(..., description="Whether model is active")
    is_public: bool = Field(default=False)

    @field_validator("type")
    @classmethod
    def validate_type(cls, v):
        """Lower-case the model type and check it against the known list."""
        valid_types = ["llm", "embedding", "rerank"]
        lowered = v.lower()
        if lowered in valid_types:
            return lowered
        raise ValueError(f"Model type must be one of: {valid_types}")

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Only active models may be used in a configuration."""
        if v:
            return v
        raise ValueError("Model must be active for configuration use")
# ==================== Validation Helper Functions ====================
def validate_memory_config_data(
    config_data: Dict[str, Any], config_id: Optional[int] = None
) -> MemoryConfigValidation:
    """Validate memory configuration data using the Pydantic model.

    Args:
        config_data: Raw field values, passed as keyword arguments to
            ``MemoryConfigValidation``.
        config_id: Optional configuration ID attached to any raised error.

    Returns:
        The validated ``MemoryConfigValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails. The message lists
            every failing field; ``field_name`` / ``invalid_value`` describe
            the first failure. The original ``ValidationError`` is attached
            as ``__cause__``.
    """
    try:
        return MemoryConfigValidation(**config_data)
    except ValidationError as e:
        # errors() rebuilds the list on every call, so compute it once.
        errors = e.errors()
        error_messages = []
        for error in errors:
            field_path = " -> ".join(str(loc) for loc in error["loc"])
            error_messages.append(f"Field '{field_path}': {error['msg']}")
        detailed_message = "Configuration validation failed:\n" + "\n".join(
            f" - {msg}" for msg in error_messages
        )
        first_error = errors[0] if errors else {}
        first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
        # Chain explicitly so the traceback shows the Pydantic error as the
        # direct cause rather than as an incidental error during handling.
        raise InvalidConfigError(
            detailed_message,
            field_name=first_field or None,
            invalid_value=first_error.get("input"),
            config_id=config_id,
        ) from e
def validate_workspace_data(
    workspace_data: Dict[str, Any], config_id: Optional[int] = None
) -> WorkspaceValidation:
    """Validate workspace data using the Pydantic model.

    Args:
        workspace_data: Raw field values, passed as keyword arguments to
            ``WorkspaceValidation``.
        config_id: Optional configuration ID attached to any raised error.

    Returns:
        The validated ``WorkspaceValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails. The message lists
            every failing field; ``field_name`` / ``invalid_value`` describe
            the first failure, and the raw ``"id"`` entry (if any) is passed
            along as ``workspace_id``. The original ``ValidationError`` is
            attached as ``__cause__``.
    """
    try:
        return WorkspaceValidation(**workspace_data)
    except ValidationError as e:
        # errors() rebuilds the list on every call, so compute it once.
        errors = e.errors()
        error_messages = []
        for error in errors:
            field_path = " -> ".join(str(loc) for loc in error["loc"])
            error_messages.append(f"Field '{field_path}': {error['msg']}")
        detailed_message = "Workspace validation failed:\n" + "\n".join(
            f" - {msg}" for msg in error_messages
        )
        first_error = errors[0] if errors else {}
        first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
        workspace_id = workspace_data.get("id") if isinstance(workspace_data, dict) else None
        # Chain explicitly so the traceback shows the Pydantic error as the
        # direct cause rather than as an incidental error during handling.
        raise InvalidConfigError(
            detailed_message,
            field_name=first_field or None,
            invalid_value=first_error.get("input"),
            config_id=config_id,
            workspace_id=workspace_id,
        ) from e
def validate_model_data(
    model_data: Dict[str, Any], config_id: Optional[int] = None
) -> ModelValidation:
    """Validate model data using the Pydantic model.

    Args:
        model_data: Raw field values, passed as keyword arguments to
            ``ModelValidation``.
        config_id: Optional configuration ID attached to any raised error.

    Returns:
        The validated ``ModelValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails. The message lists
            every failing field; ``field_name`` / ``invalid_value`` describe
            the first failure. The original ``ValidationError`` is attached
            as ``__cause__``.
    """
    try:
        return ModelValidation(**model_data)
    except ValidationError as e:
        # errors() rebuilds the list on every call, so compute it once.
        errors = e.errors()
        error_messages = []
        for error in errors:
            field_path = " -> ".join(str(loc) for loc in error["loc"])
            error_messages.append(f"Field '{field_path}': {error['msg']}")
        detailed_message = "Model validation failed:\n" + "\n".join(
            f" - {msg}" for msg in error_messages
        )
        first_error = errors[0] if errors else {}
        first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))
        # Chain explicitly so the traceback shows the Pydantic error as the
        # direct cause rather than as an incidental error during handling.
        raise InvalidConfigError(
            detailed_message,
            field_name=first_field or None,
            invalid_value=first_error.get("input"),
            config_id=config_id,
        ) from e
# ==================== Immutable Configuration Data Structure ====================
@dataclass(frozen=True)
class MemoryConfig:
    """Immutable memory configuration loaded from database.

    The dataclass is frozen, so attributes cannot be rebound after
    construction. ``__post_init__`` rejects a blank configuration name and a
    missing embedding or LLM model ID.
    """

    # Identity of the configuration and its owning workspace/tenant.
    config_id: int
    config_name: str
    workspace_id: UUID
    workspace_name: str
    tenant_id: UUID

    # Required models.
    embedding_model_id: UUID
    embedding_model_name: str
    llm_model_id: UUID
    llm_model_name: str

    # Storage and processing options.
    storage_type: str
    chunker_strategy: str
    reflexion_enabled: bool
    reflexion_iteration_period: int
    reflexion_range: str
    reflexion_baseline: str

    # Timestamp supplied by whoever built this object.
    loaded_at: datetime

    # Optional rerank model and free-form parameters.
    rerank_model_id: Optional[UUID] = None
    rerank_model_name: Optional[str] = None
    llm_params: Dict[str, Any] = field(default_factory=dict)
    embedding_params: Dict[str, Any] = field(default_factory=dict)
    config_version: str = "2.0"

    def __post_init__(self):
        """Validate configuration after initialization.

        Raises:
            InvalidConfigError: If the name is blank or a required model ID
                is falsy.
        """
        if not self.config_name or not self.config_name.strip():
            raise InvalidConfigError("Configuration name cannot be empty")
        if not self.embedding_model_id:
            raise InvalidConfigError("Embedding model ID is required")
        if not self.llm_model_id:
            raise InvalidConfigError("LLM model ID is required")

    @classmethod
    def from_validated_data(
        cls, validated_config: MemoryConfigValidation, loaded_at: datetime
    ) -> "MemoryConfig":
        """Create MemoryConfig from validated Pydantic data.

        Args:
            validated_config: Already-validated configuration values.
            loaded_at: Timestamp to record on the resulting object.

        Returns:
            A new ``MemoryConfig`` carrying the validated values.
        """
        return cls(
            config_id=validated_config.config_id,
            config_name=validated_config.config_name,
            workspace_id=validated_config.workspace_id,
            workspace_name=validated_config.workspace_name,
            tenant_id=validated_config.tenant_id,
            embedding_model_id=validated_config.embedding_model_id,
            embedding_model_name=validated_config.embedding_model_name,
            storage_type=validated_config.storage_type,
            chunker_strategy=validated_config.chunker_strategy,
            reflexion_enabled=validated_config.reflexion_enabled,
            reflexion_iteration_period=validated_config.reflexion_iteration_period,
            reflexion_range=validated_config.reflexion_range,
            reflexion_baseline=validated_config.reflexion_baseline,
            loaded_at=loaded_at,
            llm_model_id=validated_config.llm_model_id,
            llm_model_name=validated_config.llm_model_name,
            rerank_model_id=validated_config.rerank_model_id,
            rerank_model_name=validated_config.rerank_model_name,
            # Shallow-copy the parameter dicts: the validation model stays
            # mutable, and sharing its dicts would let later changes to it
            # leak into this supposedly immutable configuration.
            llm_params=dict(validated_config.llm_params),
            embedding_params=dict(validated_config.embedding_params),
            config_version=validated_config.config_version,
        )

    def get_model_summary(self) -> Dict[str, Optional[str]]:
        """Return the configured model names keyed by model type."""
        return {
            "llm": self.llm_model_name,
            "embedding": self.embedding_model_name,
            "rerank": self.rerank_model_name,
        }

    def is_model_configured(self, model_type: str) -> bool:
        """Check if a specific model type is configured.

        Args:
            model_type: One of ``"llm"``, ``"embedding"`` or ``"rerank"``.

        Returns:
            Always ``True`` for ``"llm"`` and ``"embedding"``; for
            ``"rerank"``, whether a rerank model ID is set.

        Raises:
            ValueError: If ``model_type`` is not one of the known types.
        """
        if model_type in ("llm", "embedding"):
            return True
        if model_type == "rerank":
            return self.rerank_model_id is not None
        raise ValueError(f"Unknown model type: {model_type}")

View File

@@ -4,45 +4,44 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import json
import os
import re
import time
import json
import uuid
from threading import Lock
from typing import Dict, List, Optional, Any, AsyncGenerator
from app.services.memory_konwledges_server import write_rag
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from sqlalchemy.orm import Session
from sqlalchemy import func
from pydantic import BaseModel, Field
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.db import get_db
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
from app.db import get_db
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.schemas.file_schema import CustomTextFileCreate
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__)
config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
@@ -56,6 +55,27 @@ class MemoryAgentService:
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def load_memory_config(self, config_id: int) -> MemoryConfig:
    """Fetch the memory configuration identified by ``config_id``.

    The work is handed to ``MemoryConfigService.load_memory_config`` so the
    loading logic lives in one place; this service only identifies itself
    as the caller.

    Args:
        config_id: Configuration ID from database

    Returns:
        MemoryConfig: Immutable configuration object

    Raises:
        ConfigurationError: If validation fails
    """
    caller_name = "MemoryAgentService"
    return MemoryConfigService.load_memory_config(
        config_id=config_id,
        service_name=caller_name,
    )
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
countext = re.findall(r'"status": "(.*?)",', messages)[0]
@@ -277,19 +297,20 @@ class MemoryAgentService:
import time
start_time = time.time()
# 如果 config_id 为 None使用默认值 "17"
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
@@ -300,20 +321,43 @@ class MemoryAgentService:
async with client.session("data_flow") as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
workflow_errors = [] # Track errors from workflow
# Pass config_id to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
# Pass memory_config to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
logger.debug("Write graph created successfully")
config = {"configurable": {"thread_id": group_id}}
async for event in graph.astream(
{"messages": message, "config_id": config_id},
{"messages": message, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.error(f"Write workflow failed with errors: {error_details}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_details
)
raise ValueError(f"Write workflow failed: {error_details}")
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
async def read_memory(
self,
@@ -365,15 +409,15 @@ class MemoryAgentService:
group_lock = self.get_group_lock(group_id)
with group_lock:
# Step 1: Load configuration from database
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Step 1: Load configuration from database only
try:
memory_config = self.load_memory_config(config_id)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -387,8 +431,6 @@ class MemoryAgentService:
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -404,45 +446,52 @@ class MemoryAgentService:
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass config_id to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
# Pass memory_config to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
async for event in graph.astream(
{"messages": history, "config_id": config_id},
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg.__class__.__name__.lower().replace("message", ""),
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Debug: log message type and content preview
msg_type = msg.__class__.__name__
content_preview = str(msg_content)[:200] if msg_content else "empty"
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(msg_content, str):
if isinstance(content_to_parse, str):
try:
parsed = json.loads(msg_content)
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Debug: log what keys are in parsed
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
@@ -450,34 +499,14 @@ class MemoryAgentService:
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
elif isinstance(msg_content, dict):
# Check for single intermediate output
if '_intermediate' in msg_content:
intermediate_data = msg_content['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in msg_content:
for intermediate_data in msg_content['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
@@ -489,18 +518,57 @@ class MemoryAgentService:
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try:
message = json.loads(message) if isinstance(message, str) else message
if isinstance(message, dict) and message.get('status') != '':
summary_result = message.get('summary_result')
if summary_result:
final_answer = summary_result
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
if audit_logger:
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
@@ -612,19 +680,25 @@ class MemoryAgentService:
else:
return output
async def classify_message_type(self, message: str) -> Dict:
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
Args:
message: User message to classify
config_id: Configuration ID to load LLM model from database
db: Database session
Returns:
Type classification result
"""
logger.info("Classifying message type")
status = await status_typle(message)
# Load configuration to get LLM model ID
memory_config = self.load_memory_config(config_id)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
@@ -790,7 +864,8 @@ class MemoryAgentService:
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -801,6 +876,7 @@ class MemoryAgentService:
参数:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
返回格式:
{
@@ -862,15 +938,17 @@ class MemoryAgentService:
await connector.close()
if not statements:
if not statements or not llm_id:
result["tags"] = []
if not llm_id and statements:
logger.warning("llm_id not provided, skipping tag generation")
else:
# 构建摘要文本
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client()
llm_client = get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):

View File

@@ -0,0 +1,264 @@
"""
Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
Database session management is handled internally.
"""
import time
from datetime import datetime
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_model_exists_and_active,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
def _validate_config_id(config_id):
"""Validate configuration ID format."""
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
field_name="config_id",
invalid_value=config_id,
)
if isinstance(config_id, int):
if config_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {config_id}",
field_name="config_id",
invalid_value=config_id,
)
return config_id
if isinstance(config_id, str):
try:
parsed_id = int(config_id.strip())
if parsed_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {parsed_id}",
field_name="config_id",
invalid_value=config_id,
)
return parsed_id
except ValueError as e:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
invalid_value=config_id,
)
raise InvalidConfigError(
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
field_name="config_id",
invalid_value=config_id,
)
class MemoryConfigService:
"""
Centralized service for memory configuration loading and validation.
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Database session management is handled internally.
"""
@staticmethod
def load_memory_config(
    config_id: int,
    service_name: str = "MemoryConfigService",
) -> MemoryConfig:
    """Load the memory configuration for ``config_id`` using a private session.

    A database session is taken from the ``get_db`` generator, handed to
    ``_load_memory_config_with_db``, and closed again before returning,
    whether or not loading succeeded.

    Args:
        config_id: Configuration ID from database
        service_name: Name of the calling service (for logging purposes)

    Returns:
        MemoryConfig: Immutable configuration object

    Raises:
        ConfigurationError: If validation fails
    """
    from app.db import get_db

    # Keep a reference to the generator for the whole call so it stays
    # alive while the session it yielded is in use.
    session_source = get_db()
    session = next(session_source)
    try:
        loaded = MemoryConfigService._load_memory_config_with_db(
            config_id=config_id,
            db=session,
            service_name=service_name,
        )
    finally:
        session.close()
    return loaded
@staticmethod
def _load_memory_config_with_db(
config_id: int,
db: Session,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""Internal method that loads memory configuration with an existing db session."""
start_time = time.time()
config_logger.info(
"Starting memory configuration loading",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": config_id,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}")
try:
validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(db, validated_config_id)
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Configuration not found in database",
extra={
"operation": "load_memory_config",
"config_id": validated_config_id,
"load_result": "not_found",
"elapsed_ms": elapsed_ms,
"service": service_name,
},
)
raise ConfigurationError(
f"Configuration {validated_config_id} not found in database"
)
memory_config, workspace = result
# Validate embedding model
embedding_uuid = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
db,
workspace.tenant_id,
workspace.id,
)
# Resolve LLM model
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Resolve optional rerank model
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Get embedding model name
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
config_name=memory_config.config_name,
workspace_id=workspace.id,
workspace_name=workspace.name,
tenant_id=workspace.tenant_id,
llm_model_id=llm_uuid,
llm_model_name=llm_name,
embedding_model_id=embedding_uuid,
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(),
)
elapsed_ms = (time.time() - start_time) * 1000
config_logger.info(
"Memory configuration loaded successfully",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": validated_config_id,
"config_name": config.config_name,
"workspace_id": str(config.workspace_id),
"load_result": "success",
"elapsed_ms": elapsed_ms,
},
)
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
return config
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load memory configuration",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": config_id,
"load_result": "error",
"error_type": type(e).__name__,
"error_message": str(e),
"elapsed_ms": elapsed_ms,
},
exc_info=True,
)
logger.error(f"Failed to load memory configuration {config_id}: {e}")
if isinstance(e, (ConfigurationError, ValueError)):
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")

View File

@@ -4,39 +4,40 @@ Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any, AsyncGenerator
import os
import json
import asyncio
import json
import os
import time
from datetime import datetime
from sqlalchemy.orm import Session
from typing import Any, AsyncGenerator, Dict, List, Optional
from dotenv import load_dotenv
from app.models.user_model import User
from app.models.end_user_model import EndUser
from app.core.logging_config import get_logger
from app.utils.sse_utils import format_sse_message
from app.schemas.memory_storage_schema import (
ConfigFilter,
ConfigPilotRun,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.models.end_user_model import EndUser
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.schemas.memory_storage_schema import (
ConfigFilter,
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
# Load environment variables for Neo4j connector
load_dotenv()
@@ -48,6 +49,27 @@ class MemoryStorageService:
def __init__(self):
    """Create the service; the only action taken is an informational log entry."""
    logger.info("MemoryStorageService initialized")
def load_memory_config(self, config_id: int, db: Session) -> MemoryConfig:
    """Fetch the memory configuration identified by ``config_id``.

    Thin wrapper that hands the request to
    ``MemoryConfigService.load_memory_config``, tagging it with this
    service's name so the shared loader's log records identify the caller.

    Args:
        config_id: Configuration ID from the database.
        db: Accepted as part of the signature but not used here; it is not
            forwarded to ``MemoryConfigService``.

    Returns:
        MemoryConfig: The configuration object returned by the shared loader.

    Raises:
        ConfigurationError: If validation fails in the shared loader.
    """
    loaded_config = MemoryConfigService.load_memory_config(
        service_name="MemoryStorageService",
        config_id=config_id,
    )
    return loaded_config
async def get_storage_info(self) -> dict:
"""
@@ -248,7 +270,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
RuntimeError: 当管线执行失败时
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
try:
# 发出初始进度事件
@@ -257,24 +278,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000)
})
# 步骤 1: 配置加载和验证(复用现有逻辑
# 步骤 1: 配置加载和验证(数据库优先
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
raise ValueError("未提供 payload.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
@@ -282,12 +291,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 应用内存覆写并刷新常量
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# Load configuration from database only using centralized manager
try:
memory_config = MemoryConfigService.load_memory_config(
config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
raise RuntimeError(f"Configuration loading failed: {e}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件

View File

@@ -1,28 +1,30 @@
import os
import asyncio
from typing import Any, Dict, List, Optional
import requests
from datetime import datetime, timezone
import json
import os
import time
import uuid
from datetime import datetime, timezone
from math import ceil
import redis
import json
from typing import Any, Dict, List, Optional
from app.db import get_db
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.chat_model import Base
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.models.chunk import DocumentChunk
from app.services.memory_agent_service import MemoryAgentService
from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.prompts.generator import question_proposal
import redis
import requests
# Import a unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.services.memory_agent_service import MemoryAgentService
@celery_app.task(name="tasks.process_item")
@@ -221,11 +223,17 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
else:
detailed_error = str(e)
return {
"status": "FAILURE",
"error": str(e),
"error": detailed_error,
"group_id": group_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
@@ -283,11 +291,17 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
else:
detailed_error = str(e)
return {
"status": "FAILURE",
"error": str(e),
"error": detailed_error,
"group_id": group_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
@@ -300,9 +314,10 @@ def reflection_engine() -> None:
Intentionally left blank; replace with real reflection logic later.
"""
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
import asyncio
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
asyncio.run(self_reflexion(host_id))
@@ -377,10 +392,10 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.services.memory_storage_service import search_all
from app.repositories.memory_increment_repository import write_memory_increment
from app.models.end_user_model import EndUser
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
db = next(get_db())
try: