Merge #44 into develop from refactor/memory-config-management

Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management

* refactor/memory-config-management: (7 commits)
  refactor(memory): restructure memory agent and config management
  Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
  refactor(memory): restructure memory system and improve configuration management
  Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
  refactor(memory): reorganize imports and move MemoryClientFactory to utils
  feat(memory): make config_id optional and improve configuration validation
  Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management

Signed-off-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/44
This commit is contained in:
朱文辉
2025-12-24 16:12:13 +08:00
89 changed files with 5044 additions and 4864 deletions

View File

@@ -1,9 +1,9 @@
import os
from datetime import timedelta
from urllib.parse import quote
from celery import Celery
from app.core.config import settings
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from celery import Celery
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0)
@@ -13,7 +13,6 @@ celery_app = Celery(
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
@@ -22,6 +21,7 @@ celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')

View File

@@ -10,22 +10,21 @@ Routes:
POST /emotion/suggestions - 获取个性化情绪建议
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user, get_db
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.emotion_schema import (
EmotionHealthRequest,
EmotionSuggestionsRequest,
EmotionTagsRequest,
EmotionWordcloudRequest,
EmotionHealthRequest,
EmotionSuggestionsRequest
)
from app.schemas.response_schema import ApiResponse
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.logging_config import get_api_logger
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
# 获取API专用日志器
api_logger = get_api_logger()
@@ -211,13 +210,28 @@ async def get_emotion_suggestions(
"""
try:
# 验证 config_id(如果提供)
# 获取终端用户关联的配置
config_id = request.config_id
if config_id is not None:
from app.controllers.memory_agent_controller import validate_config_id
if config_id is None:
# 如果没有提供 config_id,尝试获取用户关联的配置
try:
config_id = validate_config_id(config_id, db)
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(request.group_id, db)
config_id = connected_config.get("memory_config_id")
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
else:
# 如果提供了 config_id,验证其有效性
from app.services.memory_config_service import MemoryConfigService
try:
config_service = MemoryConfigService(db)
config = config_service.get_config_by_id(config_id)
if not config:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
except Exception as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议",
@@ -230,7 +244,7 @@ async def get_emotion_suggestions(
# 调用服务层
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.group_id,
config_id=config_id
db=db
)
api_logger.info(

View File

@@ -1,36 +1,28 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from typing import List, Optional
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.core.logging_config import get_api_logger
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
@@ -39,95 +31,6 @@ router = APIRouter(
)
def validate_config_id(config_id: int, db: Session) -> int:
    """
    Validate that a memory configuration exists and references usable models.

    When ``config_id`` is ``None`` the ``config_id`` environment variable is
    used as a fallback. The configuration row must exist, and both its
    ``llm_id`` and ``embedding_id`` must point at existing, active models.

    Args:
        config_id: Configuration ID to validate (may be ``None`` to use the
            environment fallback).
        db: Database session used for the existence checks.

    Returns:
        The validated config_id. NOTE(review): when the environment fallback is
        used this is the raw string from ``os.getenv`` rather than an int --
        confirm callers tolerate that.

    Raises:
        ValueError: If no config_id is available, the configuration or one of
            its models is missing/inactive, or a database error occurs. Every
            failure is surfaced as ValueError so callers need a single handler.
    """
    if config_id is None:
        api_logger.info(
            "config_id was not provided; falling back to the 'config_id' environment variable"
        )
        config_id = os.getenv('config_id')
    if config_id is None:
        raise ValueError("config_id is required but was not provided")

    try:
        # Imported lazily, as in the original implementation.
        from app.models.data_config_model import DataConfig
        from app.models.models_model import ModelConfig

        config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
        if config is None:
            error_msg = f"Configuration with config_id={config_id} does not exist in database"
            api_logger.error(error_msg)
            raise ValueError(error_msg)

        # The LLM is checked first, then the embedding model, matching the
        # order in which failures were previously reported.
        _validate_model_reference(
            db, ModelConfig, config.llm_id, config_id,
            label="LLM", noun="LLM", id_field="llm_id",
        )
        _validate_model_reference(
            db, ModelConfig, config.embedding_id, config_id,
            label="Embedding", noun="embedding", id_field="embedding_id",
        )

        api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
        return config_id
    except ValueError:
        # Validation failures are already logged; propagate them unchanged.
        raise
    except Exception as e:
        error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e


def _validate_model_reference(db: Session, model_cls, model_id, config_id, label: str, noun: str, id_field: str) -> None:
    """
    Ensure one model referenced by a configuration exists and is active.

    Args:
        db: Database session used for the lookup.
        model_cls: ORM class to query (``ModelConfig``).
        model_id: Referenced model ID taken from the configuration row.
        config_id: Owning configuration ID, used only in messages.
        label: Name used in the "does not exist" / "is not active" / success messages.
        noun: Name used in the "Error validating ... model" message.
        id_field: Configuration attribute name (``llm_id`` / ``embedding_id``).

    Raises:
        ValueError: If the reference is unset, missing, inactive, or the
            lookup itself fails.
    """
    if not model_id:
        error_msg = f"Config {config_id} has no {id_field} set"
        api_logger.error(error_msg)
        raise ValueError(error_msg)

    try:
        model = db.query(model_cls).filter(model_cls.id == model_id).first()
        if model is None:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) does not exist"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        if not model.is_active:
            error_msg = f"{label} model with id={model_id} (from config_id={config_id}) is not active"
            api_logger.error(error_msg)
            raise ValueError(error_msg)
        api_logger.debug(f"{label} validation successful: {id_field}={model_id}, name={model.name}")
    except ValueError:
        raise
    except Exception as e:
        error_msg = f"Error validating {noun} model: {str(e)}"
        api_logger.error(error_msg, exc_info=True)
        # Chain the cause so the underlying database error stays visible.
        raise ValueError(error_msg) from e
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
@@ -225,12 +128,7 @@ async def write_server(
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -265,13 +163,20 @@ async def write_server(
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
config_id,
config_id,
db,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@@ -292,12 +197,7 @@ async def write_server_async(
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
@@ -352,12 +252,7 @@ async def read_server(
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
@@ -386,12 +281,19 @@ async def read_server(
user_input.history,
user_input.search_switch,
config_id,
db,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
api_logger.error(f"Read operation error (TaskGroup): {detailed_error}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", detailed_error)
api_logger.error(f"Read operation error: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@@ -456,12 +358,7 @@ async def read_server_async(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
@@ -653,6 +550,7 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
@@ -666,7 +564,11 @@ async def status_type(
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
result = await memory_agent_service.classify_message_type(user_input.message)
result = await memory_agent_service.classify_message_type(
user_input.message,
user_input.config_id,
db
)
return success(data=result)
except Exception as e:
api_logger.error(f"Message type classification failed: {str(e)}")
@@ -741,6 +643,7 @@ async def get_hot_memory_tags_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
@@ -764,7 +667,8 @@ async def get_user_profile_api(
try:
result = await memory_agent_service.get_user_profile(
end_user_id=end_user_id,
current_user_id=str(current_user.id)
current_user_id=str(current_user.id),
db=db
)
return success(data=result, msg="获取用户详情成功")
except Exception as e:
@@ -799,4 +703,41 @@ async def get_user_profile_api(
# )
# except Exception as e:
# api_logger.error(f"API docs retrieval failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
    end_user_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the memory configuration associated with an end user.

    The lookup is delegated to the service-layer function of the same name,
    which is documented to:
    1. resolve the user's app_id from end_user_id,
    2. fetch the latest published release of that application,
    3. extract memory_config_id from the release's config field.

    Args:
        end_user_id: End user ID.

    Returns:
        A success response carrying the service result, a NOT_FOUND failure
        when the service raises ValueError, or an INTERNAL_ERROR failure for
        any other exception.
    """
    # Imported lazily and aliased: this route handler has the same name as the
    # service function it wraps.
    from app.services.memory_agent_service import (
        get_end_user_connected_config as lookup_connected_config,
    )

    api_logger.info(f"Getting connected config for end_user: {end_user_id}")

    try:
        connected_config = lookup_connected_config(end_user_id, db)
        return success(data=connected_config, msg="获取终端用户关联配置成功")
    except ValueError as not_found:
        api_logger.warning(f"End user config not found: {str(not_found)}")
        return fail(BizCode.NOT_FOUND, str(not_found))
    except Exception as unexpected:
        api_logger.error(
            f"Failed to get end user connected config: {str(unexpected)}",
            exc_info=True,
        )
        return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(unexpected))

View File

@@ -1,22 +1,27 @@
import asyncio
import time
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
ReflectionConfig,
ReflectionEngine,
)
from app.core.response_utils import success
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
from app.dependencies import get_current_user
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.orm import Session
load_dotenv()
api_logger = get_api_logger()

View File

@@ -1,50 +1,50 @@
from typing import Optional
import datetime
import os
import uuid
import datetime
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from typing import Optional
from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.services.memory_storage_service import (
MemoryStorageService,
DataConfigService,
kb_type_distribution,
search_dialogue,
search_chunk,
search_statement,
search_entity,
search_all,
search_detials,
search_edges,
search_entity_graph,
analytics_hot_memory_tags,
analytics_recent_activity_stats,
)
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
ConfigPilotRun,
GenerateCacheRequest,
)
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.end_user_model import EndUser
from app.models.user_model import User
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
)
from app.models.end_user_model import EndUser
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
GenerateCacheRequest,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_storage_service import (
DataConfigService,
MemoryStorageService,
analytics_hot_memory_tags,
analytics_recent_activity_stats,
kb_type_distribution,
search_all,
search_chunk,
search_detials,
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from app.services.user_memory_service import analytics_user_summary
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
# Get API logger
api_logger = get_api_logger()
@@ -335,8 +335,10 @@ async def pilot_run(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StreamingResponse:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}"
)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),
@@ -344,8 +346,8 @@ async def pilot_run(
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
"X-Accel-Buffering": "no",
},
)
"""
@@ -508,6 +510,20 @@ async def get_recent_activity_stats_api(
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
    end_user_id: Optional[str] = None,
    current_user: User = Depends(get_current_user),
) -> dict:
    """
    Return the analytics summary for an end user.

    Args:
        end_user_id: Optional end user ID forwarded to ``analytics_user_summary``.
        current_user: Authenticated user (dependency-injected; only used to
            require authentication).

    Returns:
        A success response wrapping the summary, or an INTERNAL_ERROR failure
        response if summary generation raises.
    """
    api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
    try:
        result = await analytics_user_summary(end_user_id)
        return success(data=result, msg="查询成功")
    except Exception as e:
        # Include the traceback so failures can be diagnosed from the log;
        # the response only carries the message text.
        api_logger.error(f"User summary failed: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""

View File

@@ -9,18 +9,19 @@ LangChain Agent 封装
"""
import os
import time
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
@@ -198,10 +199,24 @@ class LangChainAgent:
"""
message_chat= message
start_time = time.time()
if config_id == None:
actual_config_id = os.getenv("config_id")
else:
actual_config_id = config_id
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
@@ -295,10 +310,24 @@ class LangChainAgent:
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
message_chat = message
if config_id == None:
actual_config_id = os.getenv("config_id")
else:
actual_config_id = config_id
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id)
if memory_flag:

View File

@@ -1,7 +1,8 @@
import os
import json
import os
from pathlib import Path
from typing import Dict, Any, Optional
from typing import Any, Dict, Optional
from dotenv import load_dotenv
load_dotenv()
@@ -81,6 +82,7 @@ class Settings:
VOLC_QUERY_URL: str = os.getenv("VOLC_QUERY_URL", "https://openspeech.bytedance.com/api/v3/auc/bigmodel/query")
# Langfuse configuration
LANGFUSE_ENABLED: bool = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
LANGFUSE_PUBLIC_KEY: str = os.getenv("LANGFUSE_PUBLIC_KEY", "")
LANGFUSE_SECRET_KEY: str = os.getenv("LANGFUSE_SECRET_KEY", "")
LANGFUSE_HOST: str = os.getenv("LANGFUSE_HOST", "")
@@ -156,9 +158,6 @@ class Settings:
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
MEMORY_CONFIG_FILE: str = os.getenv("MEMORY_CONFIG_FILE", "config.json")
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
@@ -181,65 +180,6 @@ class Settings:
return str(base_path / filename)
return str(base_path)
def get_memory_config_path(self, config_file: str = "") -> str:
    """
    Build the path of a memory-module configuration file.

    Args:
        config_file: File name inside ``MEMORY_CONFIG_DIR``. An empty value
            selects ``MEMORY_CONFIG_FILE``.

    Returns:
        ``MEMORY_CONFIG_DIR`` joined with the chosen file name, as a string.
    """
    # Any falsy name falls back to the default config file.
    chosen_name = config_file or self.MEMORY_CONFIG_FILE
    config_dir = Path(self.MEMORY_CONFIG_DIR)
    return str(config_dir.joinpath(chosen_name))
def load_memory_config(self) -> Dict[str, Any]:
    """
    Read the memory-module configuration file (``MEMORY_CONFIG_FILE``).

    Returns:
        The parsed JSON content, or an empty dict when the file is missing or
        is not valid JSON (a warning is printed in that case).
    """
    config_path = self.get_memory_config_path(self.MEMORY_CONFIG_FILE)
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Best effort: a missing or broken file yields an empty configuration.
        print(f"Warning: Memory config file not found or malformed at {config_path}. Error: {e}")
        return {}
    return loaded
def load_memory_runtime_config(self) -> Dict[str, Any]:
    """
    Read the memory-module runtime configuration file (``MEMORY_RUNTIME_FILE``).

    Returns:
        The parsed JSON content, or ``{"selections": {}}`` when the file is
        missing or is not valid JSON (a warning is printed in that case).
    """
    runtime_path = self.get_memory_config_path(self.MEMORY_RUNTIME_FILE)
    try:
        with open(runtime_path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Best effort: fall back to an empty selection set.
        print(f"Warning: Memory runtime config not found or malformed at {runtime_path}. Error: {e}")
        return {"selections": {}}
    return loaded
def load_memory_dbrun_config(self) -> Dict[str, Any]:
    """
    Read the memory-module database-run configuration file (``MEMORY_DBRUN_FILE``).

    Returns:
        The parsed JSON content, or ``{"selections": {}}`` when the file is
        missing or is not valid JSON (a warning is printed in that case).
    """
    dbrun_path = self.get_memory_config_path(self.MEMORY_DBRUN_FILE)
    try:
        with open(dbrun_path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        # Best effort: fall back to an empty selection set.
        print(f"Warning: Memory dbrun config not found or malformed at {dbrun_path}. Error: {e}")
        return {"selections": {}}
    return loaded
def ensure_memory_output_dir(self) -> None:
"""
Ensure the memory output directory exists.

View File

@@ -326,7 +326,7 @@ def log_prompt_rendering(prompt_type: str, content: str) -> None:
logger.info(log_message)
def log_template_rendering(template_name: str, context: dict | None = None) -> None:
def log_template_rendering(template_name: str, context: Optional[dict] = None) -> None:
"""Log template rendering information.
Logs the template name and context keys for debugging template rendering.
@@ -575,6 +575,43 @@ def get_named_logger(name: str) -> logging.Logger:
return get_agent_logger(name)
def get_config_logger() -> logging.Logger:
    """Return the logger used for memory configuration operations.

    Memory logging is set up on first use via
    ``LoggingConfig.setup_memory_logging()``; afterwards the logger named
    ``memory.config`` is returned. No handlers or levels are attached here.

    Returns:
        The ``memory.config`` logger.

    Example:
        >>> logger = get_config_logger()
        >>> logger.info("Loading configuration", extra={
        ...     "config_id": 123,
        ...     "workspace_id": "uuid-here",
        ...     "operation": "load_config"
        ... })
    """
    # Initialise the memory logging setup lazily, exactly once.
    already_initialized = LoggingConfig._memory_loggers_initialized
    if not already_initialized:
        LoggingConfig.setup_memory_logging()

    return logging.getLogger("memory.config")
def get_memory_logger(name: Optional[str] = None) -> logging.Logger:
"""Get a standard logger for memory module components.

View File

@@ -9,11 +9,11 @@ import logging
import re
import uuid
from datetime import datetime
from typing import Dict, Any
from langchain_core.messages import AIMessage
from typing import Any, Dict
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
@@ -25,7 +25,8 @@ async def create_input_message(
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor
multimodal_processor: MultimodalProcessor,
memory_config: MemoryConfig,
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
@@ -46,6 +47,7 @@ async def create_input_message(
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
memory_config: MemoryConfig object containing all configuration
Returns:
State update with AIMessage containing tool_call
@@ -53,7 +55,7 @@ async def create_input_message(
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
@@ -123,20 +125,24 @@ async def create_input_message(
f"with ID: {tool_call_id}"
)
# Build tool arguments
tool_args = {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
}
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id
},
"args": tool_args,
"id": tool_call_id
}]
)

View File

@@ -9,14 +9,14 @@ import logging
import time
from typing import Any, Callable, Dict
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_content_payload,
extract_tool_call_id,
extract_content_payload
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
@@ -38,8 +38,9 @@ class ToolExecutionNode:
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
memory_config: MemoryConfig object containing all configuration
"""
def __init__(
self,
tool: Callable,
@@ -49,8 +50,9 @@ class ToolExecutionNode:
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type:str,
user_rag_memory_id:str
storage_type: str,
user_rag_memory_id: str,
memory_config: MemoryConfig,
):
"""
Initialize the tool execution node.
@@ -63,6 +65,9 @@ class ToolExecutionNode:
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
storage_type: Storage type for the workspace
user_rag_memory_id: User RAG memory identifier
memory_config: MemoryConfig object containing all configuration
"""
self.tool_node = ToolNode([tool])
self.id = node_id
@@ -72,9 +77,10 @@ class ToolExecutionNode:
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type=storage_type
self.user_rag_memory_id=user_rag_memory_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
self.memory_config = memory_config
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
@@ -124,8 +130,12 @@ class ToolExecutionNode:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}"
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
)
# Log raw message content for debugging
if hasattr(last_message, 'content'):
raw = last_message.content
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
except Exception as e:
logger.error(
@@ -143,8 +153,9 @@ class ToolExecutionNode:
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
memory_config=self.memory_config,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id
user_rag_memory_id=self.user_rag_memory_id,
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
@@ -179,7 +190,29 @@ class ToolExecutionNode:
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Return the result directly - it already contains the messages list
# Check for error in tool response
error_entry = None
if result and "messages" in result:
for msg in result["messages"]:
if hasattr(msg, 'content'):
try:
import json
content = msg.content
if isinstance(content, str):
parsed = json.loads(content)
if isinstance(parsed, dict) and "error" in parsed:
error_msg = parsed["error"]
logger.warning(
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
)
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
except (json.JSONDecodeError, TypeError):
pass
# Return result with error tracking if error was found
if error_entry:
result["errors"] = [error_entry]
return result
except Exception as e:
@@ -187,13 +220,15 @@ class ToolExecutionNode:
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Return error as ToolMessage to maintain message chain consistency
# Track error in state and return error message
from langchain_core.messages import ToolMessage
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
]
],
"errors": [error_entry]
}

View File

@@ -1,38 +1,26 @@
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
import warnings
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Literal
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from functools import partial
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
from langgraph.checkpoint.memory import InMemorySaver
from app.core.memory.agent.utils.redis_tool import store
from app.core.logging_config import get_agent_logger
# Import new modular components
from app.core.memory.agent.langgraph_graph.nodes import ToolExecutionNode, create_input_message
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue
from app.core.memory.agent.langgraph_graph.nodes import (
ToolExecutionNode,
create_input_message,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
logger = get_agent_logger(__name__)
@@ -44,9 +32,9 @@ redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# 在工作流中添加循环计数更新
# Update loop count in workflow
async def update_loop_count(state):
    """Return a state update that increments the workflow loop counter.

    Args:
        state: Graph state supporting ``.get``; ``loop_count`` is treated
            as 0 when the key is absent.

    Returns:
        A dict ``{"loop_count": previous_value + 1}``.
    """
    current_count = state.get("loop_count", 0)
    return {"loop_count": current_count + 1}
@@ -54,13 +42,13 @@ async def update_loop_count(state):
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# 添加边界检查
# Add boundary check
if not messages:
return END
counter.add(1) # 累加 1
counter.add(1) # Increment by 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] 当前循环次数: {loop_count}")
logger.debug(f"[should_continue] Current loop count: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
@@ -71,15 +59,15 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # 最大循环次数 3
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# 添加默认返回值,避免返回 None
# Add default return value to avoid returning None
counter.reset()
return "Summary" # 或根据业务需求选择合适的默认值
return "Summary" # Default based on business requirements
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
@@ -115,8 +103,8 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1':
return 'Retrieve_Summary'
# 添加默认返回值,避免返回 None
return 'Retrieve_Summary' # 或根据业务逻辑选择合适的默认值
# Add default return value to avoid returning None
return 'Retrieve_Summary' # Default based on business logic
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
@@ -151,46 +139,7 @@ def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
# 在 input_sentence 函数中修改参数名称
async def input_sentence(state, name, id, search_switch, apply_id, group_id):
    """Build the AIMessage that issues the initial tool call for a turn.

    The content of the last message in ``state["messages"]`` is used as the
    sentence. Paths ending in ``.jpg``/``.png`` are first passed through
    ``picture_model_requests``; paths ending in one of ``audio_extensions``
    are passed through ``Vico_recognition``. When the resulting text contains
    ``verified_data``, the embedded ``"query"`` value replaces the sentence.

    Args:
        state: Graph state; only ``state["messages"]`` is read.
        name: Tool name placed in the tool call.
        id: Session identifier; also used as the prefix of the tool-call id.
            (Shadows the builtin; kept because it is part of the signature.)
        search_switch: Forwarded unchanged in the tool arguments.
        apply_id: Forwarded unchanged in the tool arguments.
        group_id: Forwarded unchanged in the tool arguments.

    Returns:
        ``{"messages": [AIMessage]}`` where the message has empty content and
        a single tool call whose id is ``f"{id}_{uuid4}"``.
    """
    messages = state["messages"]
    last_message = messages[-1].content if messages else ""

    # Multimodal inputs are converted to text before being sent to the tool.
    if last_message.endswith(('.jpg', '.png')):
        last_message = await picture_model_requests(last_message)
    if any(last_message.endswith(ext) for ext in audio_extensions):
        last_message = await Vico_recognition([last_message]).run()
        logger.debug(f"Audio recognition result: {last_message}")

    uuid_str = uuid.uuid4()

    if 'verified_data' in str(last_message):
        # Strip escaping so the embedded JSON-like text can be matched.
        messages_last = str(last_message).replace('\\n', '').replace('\\', '')
        query_match = re.search(r'"query": "(.*?)",', messages_last)
        # Keep the original text when no "query" field is present instead of
        # failing with IndexError.
        if query_match:
            last_message = query_match.group(1)

    tool_args = {
        "sentence": last_message,
        'sessionid': id,
        'messages_id': str(uuid_str),
        "search_switch": search_switch,
        "apply_id": apply_id,
        "group_id": group_id,
    }
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[{
                    "name": name,
                    "args": tool_args,
                    "id": id + f'_{uuid_str}',
                }],
            )
        ]
    }
return 'Split_The_Problem' # Default case
class ProblemExtensionNode:
@@ -208,30 +157,28 @@ class ProblemExtensionNode:
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - 当前时间: {time.time()} - Message: {last_message}")
if self.tool_name=='Input_Summary':
tool_call =re.findall("'id': '(.*?)'",str(last_message))[0]
else:tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# try:
# content = json.loads(last_message.content) if hasattr(last_message, 'content') else last_message
# except:
# content = last_message.content if hasattr(last_message, 'content') else str(last_message)
# 尝试从上一工具的结果中提取实际的内容载荷(而不是整个对象的字符串表示)
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
if self.tool_name == 'Input_Summary':
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
else:
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# Try to extract actual content payload from previous tool result
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# 捕获 ToolMessage content 字段(支持单/双引号),并避免贪婪匹配
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# 回退:直接尝试使用原始字符串
# Fallback: use raw string directly
extracted_payload = raw_msg
# 优先尝试将内容解析为 JSON
# Try to parse content as JSON first
try:
content = json.loads(extracted_payload)
except Exception:
# 尝试从文本中提取 JSON 片段再解析
# Try to extract JSON fragment from text and parse
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
@@ -240,14 +187,14 @@ class ProblemExtensionNode:
break
except Exception:
continue
# 如果仍然失败,则以原始字符串作为内容
# If still fails, use raw string as content
content = parsed if parsed is not None else extracted_payload
# 根据工具名称构建正确的参数
# Build correct parameters based on tool name
tool_args = {}
if self.tool_name == "Verify":
# Verify工具需要contextusermessages参数
# Verify tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -256,7 +203,7 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve工具需要contextusermessages参数
# Retrieve tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -266,9 +213,9 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary工具需要字符串类型的context参数
# Summary tool requires string type context parameter
if isinstance(content, dict):
# 将字典转换为JSON字符串
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
@@ -276,24 +223,24 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary工具需要字符串类型的context参数
# Summary_fails tool requires string type context parameter
if isinstance(content, dict):
# 将字典转换为JSON字符串
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name=='Input_Summary':
tool_args["context"] =str(last_message)
elif self.tool_name == 'Input_Summary':
tool_args["context"] = str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name=='Retrieve_Summary' :
elif self.tool_name == 'Retrieve_Summary':
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
@@ -320,7 +267,7 @@ class ProblemExtensionNode:
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# 其他工具使用context参数
# Other tools use context parameter
if isinstance(content, dict):
tool_args["context"] = content
else:
@@ -349,12 +296,24 @@ class ProblemExtensionNode:
@asynccontextmanager
async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config_id=None,storage_type=None,user_rag_memory_id=None):
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
"""
Create a read graph workflow for memory operations.
Args:
namespace: Namespace identifier
tools: MCP tools loaded from session
search_switch: Search mode switch ("0", "1", or "2")
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type (optional)
user_rag_memory_id: User RAG memory ID (optional)
"""
memory = InMemorySaver()
tool=[i.name for i in tools ]
tool = [i.name for i in tools]
logger.info(f"Initializing read graph with tools: {tool}")
if config_id:
logger.info(f"使用配置 ID: {config_id}")
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
@@ -382,9 +341,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
@@ -394,9 +354,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
@@ -406,7 +367,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_node = ToolExecutionNode(
@@ -418,9 +380,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
@@ -430,9 +393,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
@@ -442,9 +406,10 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
@@ -454,16 +419,16 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
return await create_input_message(
state=state,
tool_name=tool_name,
@@ -471,7 +436,8 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor
multimodal_processor=multimodal_processor,
memory_config=memory_config,
)
@@ -501,8 +467,3 @@ async def make_read_graph(namespace,tools,search_switch,apply_id,group_id,config
graph = workflow.compile(checkpointer=memory)
yield graph
# 添加到文件末尾或创建新的执行脚本
# 在 memory_agent_service.py 文件中添加以下函数

View File

@@ -128,6 +128,15 @@ def extract_content_payload(message: Any) -> Any:
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(raw_content, list):
for block in raw_content:
if isinstance(block, dict) and block.get('type') == 'text':
raw_content = block.get('text', '')
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
break
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
@@ -140,13 +149,16 @@ def extract_content_payload(message: Any) -> Any:
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
return json.loads(raw_content)
parsed = json.loads(raw_content)
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
pass
@@ -156,9 +168,12 @@ def extract_content_payload(message: Any) -> Any:
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
return json.loads(candidate)
parsed = json.loads(candidate)
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
return raw_content

View File

@@ -1,116 +1,71 @@
import asyncio
import json
from contextlib import asynccontextmanager
from langgraph.constants import START, END
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode
from app.core.memory.agent.utils.llm_tools import WriteState
import warnings
import sys
from langchain_core.messages import AIMessage
import warnings
from contextlib import asynccontextmanager
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, config_id=None):
logger.info("加载 MCP 工具: %s", [t.name for t in tools])
if config_id:
logger.info(f"使用配置 ID: {config_id}")
data_type_tool = next((t for t in tools if t.name == "Data_type_differentiation"), None)
@asynccontextmanager
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
logger.info("Loading MCP tools: %s", [t.name for t in tools])
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_type_tool or not data_write_tool:
logger.error('不存在数据存储工具', exc_info=True)
raise ValueError('不存在数据存储工具')
# ToolNode
write_node = ToolNode([data_write_tool])
if not data_write_tool:
logger.error("Data_write tool not found", exc_info=True)
raise ValueError("Data_write tool not found")
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
# 调用 Data_type_differentiation 工具
try:
raw_result = await data_type_tool.ainvoke({
"context": last_message[1] if isinstance(last_message, tuple) else last_message.content
})
# MCP工具返回的是列表格式需要提取内容
logger.debug(f"Data_type_differentiation raw result type: {type(raw_result)}, value: {raw_result}")
# 处理不同的返回格式
if isinstance(raw_result, list) and len(raw_result) > 0:
# MCP工具返回格式: [{"type": "text", "text": "..."}]
result_text = raw_result[0].get("text", "{}") if isinstance(raw_result[0], dict) else str(raw_result[0])
elif isinstance(raw_result, str):
result_text = raw_result
else:
result_text = str(raw_result)
# 解析JSON字符串
try:
result = json.loads(result_text)
except json.JSONDecodeError as je:
logger.error(f"Failed to parse result as JSON: {result_text}, error: {je}")
return {"messages": [AIMessage(content=json.dumps({
"status": "error",
"message": f"Invalid JSON response from Data_type_differentiation: {str(je)}"
}))]}
# 检查是否有错误
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("message", "Unknown error in Data_type_differentiation")
logger.error(f"Data_type_differentiation 返回错误: {error_msg}")
return {"messages": [AIMessage(content=json.dumps({
"status": "error",
"message": error_msg
}))]}
except Exception as e:
logger.error(f"调用 Data_type_differentiation 失败: {e}", exc_info=True)
return {"messages": [AIMessage(content=json.dumps({
"status": "error",
"message": f"Data type differentiation failed: {str(e)}"
}))]}
# 调用 Data_write传递 config_id
# Call Data_write directly with memory_config
write_params = {
"content": result.get("context", last_message.content if hasattr(last_message, 'content') else str(last_message)),
"content": content,
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id
"user_id": user_id,
"memory_config": memory_config,
}
# 如果提供了 config_id添加到参数中
if config_id:
write_params["config_id"] = config_id
logger.debug(f"传递 config_id 到 Data_write: {config_id}")
try:
write_result = await data_write_tool.ainvoke(write_params)
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
if isinstance(write_result, dict):
content = write_result.get("data", str(write_result))
else:
content = str(write_result)
logger.info("写入内容: %s", content)
return {"messages": [AIMessage(content=content)]}
except Exception as e:
logger.error(f"调用 Data_write 失败: {e}", exc_info=True)
return {"messages": [AIMessage(content=json.dumps({
"status": "error",
"message": f"Data write failed: {str(e)}"
}))]}
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
result_content = write_result.get("data", str(write_result))
else:
result_content = str(write_result)
logger.info("Write content: %s", result_content)
return {"messages": [AIMessage(content=result_content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)

View File

@@ -10,19 +10,19 @@ Package structure:
- models: Pydantic response models
- services: Business logic services
"""
from app.core.memory.agent.mcp_server.server import (
mcp,
initialize_context,
main,
get_context_resource
)
# from app.core.memory.agent.mcp_server.server import (
# mcp,
# initialize_context,
# main,
# get_context_resource
# )
# Import tools to register them (but don't export them)
from app.core.memory.agent.mcp_server import tools
# # Import tools to register them (but don't export them)
# from app.core.memory.agent.mcp_server import tools
__all__ = [
'mcp',
'initialize_context',
'main',
'get_context_resource',
]
# __all__ = [
# 'mcp',
# 'initialize_context',
# 'main',
# 'get_context_resource',
# ]

View File

@@ -6,19 +6,15 @@ in the context for dependency injection into tool functions.
"""
import os
import sys
from mcp.server.fastmcp import FastMCP
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.redis_tool import RedisSessionStore, store
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID,reload_configuration_from_database
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import store
logger = get_agent_logger(__name__)
@@ -78,17 +74,11 @@ def initialize_context():
logger.info("Registering session_store in context")
mcp.session_store = store
# Register LLM client
try:
logger.info(f"Registering llm_client in context with model ID: {SELECTED_LLM_ID}")
llm_client = get_llm_client(SELECTED_LLM_ID)
mcp.llm_client = llm_client
logger.info("llm_client registered successfully")
except Exception as e:
logger.error(f"Failed to register llm_client: {e}", exc_info=True)
# 注册一个 None 值,避免工具调用时找不到资源
mcp.llm_client = None
logger.warning("llm_client set to None due to initialization failure")
# Note: LLM client is NOT loaded at server startup
# It should be loaded dynamically when needed, with config_id passed explicitly
# to make_write_graph or make_read_graph functions
logger.info("LLM client will be loaded dynamically with config_id when needed")
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
@@ -124,26 +114,20 @@ def main():
Initializes context and starts the server with SSE transport.
"""
try:
# logger.info("Starting MCP server initialization")
reload_configuration_from_database(config_id=os.getenv("config_id"), force_reload=True)
logger.info("Starting MCP server initialization")
# Initialize context resources
initialize_context()
# Import and register tools
# logger.info("Importing MCP tools")
from app.core.memory.agent.mcp_server.tools import (
# Import and register tools (imports trigger tool registration)
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
data_tools,
problem_tools,
retrieval_tools,
verification_tools,
summary_tools,
data_tools
verification_tools,
)
# logger.info("All MCP tools imported and registered")
# Log registered tools for debugging
import asyncio
tools_list = asyncio.run(mcp.list_tools())
# logger.info(f"Registered {len(tools_list)} MCP tools: {[t.name for t in tools_list]}")
# Tools are registered via imports above
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))

View File

@@ -4,22 +4,22 @@ Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
import json
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
    """Set up the builder.

    No instance attributes are assigned here; construction only emits an
    info-level log entry.
    """
    logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
@@ -28,8 +28,9 @@ class ParameterBuilder:
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
@@ -48,6 +49,7 @@ class ParameterBuilder:
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -58,18 +60,19 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"group_id": group_id
"group_id": group_id,
"memory_config": memory_config,
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
# Verify expects dict context
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
# These tools expect dict context
return {
"context": content if isinstance(content, dict) else {},
"context": content if isinstance(content, dict) else {"content": content},
**base_args
}

View File

@@ -4,21 +4,31 @@ Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self, memory_config: Optional["MemoryConfig"] = None):
    """
    Initialize the search service.

    Args:
        memory_config: Optional MemoryConfig for embedding model configuration.
            If not provided, must be passed to execute_hybrid_search.
    """
    # Stored as the fallback configuration for execute_hybrid_search.
    self.memory_config = memory_config
    logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
@@ -93,12 +103,13 @@ class SearchService:
self,
group_id: str,
question: str,
limit: int = 5,
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -112,6 +123,7 @@ class SearchService:
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
@@ -119,12 +131,17 @@ class SearchService:
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Use provided memory_config or fall back to instance config
config = memory_config or self.memory_config
if not config:
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
# Execute search using memory_config
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
@@ -132,7 +149,8 @@ class SearchService:
limit=limit,
include=include,
output_path=output_path,
rerank_alpha=rerank_alpha
memory_config=config,
rerank_alpha=rerank_alpha,
)
# Extract results based on search type and include parameter

View File

@@ -3,16 +3,20 @@ Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from mcp.server.fastmcp import Context
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.retrieval_models import (
DistinguishTypeResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.retrieval_models import DistinguishTypeResponse
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -20,7 +24,8 @@ logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str
context: str,
memory_config: MemoryConfig,
) -> dict:
"""
Distinguish the type of data (read or write).
@@ -28,6 +33,7 @@ async def Data_type_differentiation(
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
memory_config: MemoryConfig object containing LLM configuration
Returns:
dict: Contains 'context' with the original text and 'type' field
@@ -35,7 +41,11 @@ async def Data_type_differentiation(
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
# Get LLM client from memory_config using factory pattern
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template
try:
@@ -53,7 +63,7 @@ async def Data_type_differentiation(
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
@@ -98,7 +108,7 @@ async def Data_write(
user_id: str,
apply_id: str,
group_id: str,
config_id: str
memory_config: MemoryConfig,
) -> dict:
"""
Write data to the database/file system.
@@ -109,7 +119,7 @@ async def Data_write(
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
config_id: Configuration ID for processing (optional, integer)
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
@@ -118,32 +128,28 @@ async def Data_write(
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data using utility function
try:
await write(content, user_id, apply_id, group_id, config_id=config_id)
logger.info(f"写入成功Config ID: {config_id if config_id else 'None'}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": config_id
}
except Exception as e:
logger.error(f"写入失败: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
}
except Exception as e:
logger.error(
f"Data_write failed: {e}",
exc_info=True
# Write data - clients are constructed inside write() from memory_config
await write(
content=content,
user_id=user_id,
apply_id=apply_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
return {
"status": "error",
"message": str(e)
"message": str(e),
}

View File

@@ -2,25 +2,24 @@
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import time
from typing import List
from pydantic import BaseModel, Field, RootModel
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownItem,
ProblemBreakdownResponse,
ExtendedQuestionItem,
ProblemExtensionResponse
ProblemExtensionResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -32,7 +31,8 @@ async def Split_The_Problem(
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
@@ -44,17 +44,22 @@ async def Split_The_Problem(
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
@@ -116,8 +121,8 @@ async def Split_The_Problem(
)
split_result = json.dumps([], ensure_ascii=False)
logger.info("问题拆分")
logger.info(f"问题拆分结果==>>:{split_result}")
logger.info("Problem splitting")
logger.info(f"Problem split result: {split_result}")
# Emit intermediate output for frontend
result = {
@@ -150,7 +155,7 @@ async def Split_The_Problem(
duration = end - start
except Exception:
duration = 0.0
log_time('问题拆分', duration)
log_time('Problem splitting', duration)
@mcp.tool()
@@ -160,8 +165,9 @@ async def Problem_Extension(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Extend the problem with additional sub-questions.
@@ -172,6 +178,7 @@ async def Problem_Extension(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -179,12 +186,16 @@ async def Problem_Extension(
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
@@ -250,8 +261,8 @@ async def Problem_Extension(
)
aggregated_dict = {}
logger.info("问题扩展")
logger.info(f"问题扩展==>>:{aggregated_dict}")
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
result = {
@@ -290,4 +301,4 @@ async def Problem_Extension(
duration = end - start
except Exception:
duration = 0.0
log_time('问题扩展', duration)
log_time('Problem extension', duration)

View File

@@ -3,25 +3,24 @@ Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
from dotenv import load_dotenv
import os
from app.core.rag.nlp.search import knowledge_retrieval
# 加载.env文件
load_dotenv()
import time
from typing import List
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import deduplicate_entries, merge_to_key_value_pairs
from app.core.memory.agent.utils.llm_tools import (
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@@ -32,8 +31,9 @@ async def Retrieve(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Retrieve data from the database using hybrid search.
@@ -44,6 +44,7 @@ async def Retrieve(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
@@ -66,6 +67,7 @@ async def Retrieve(
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
try:
# Extract services from context
@@ -77,7 +79,13 @@ async def Retrieve(
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
logger.info(f"Retrieve: context keys={list(context.keys())}")
content, original = await Retriev_messages_deal(context)
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
if not original:
logger.warning(f"Retrieve: original query is empty! context={context}")
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
@@ -113,9 +121,11 @@ async def Retrieve(
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
databases_anser.append({
"Query_small": cleaned_query,
@@ -206,9 +216,11 @@ async def Retrieve(
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(**search_params)
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
@@ -236,7 +248,7 @@ async def Retrieve(
}
logger.info(
f"检索==>>:{storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
@@ -279,4 +291,4 @@ async def Retrieve(
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
log_time('Retrieval', duration)

View File

@@ -2,33 +2,32 @@
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import os
import re
import time
from typing import List
from pydantic import BaseModel, Field
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.mcp_server.models.summary_models import (
SummaryData,
RetrieveSummaryResponse,
SummaryResponse,
RetrieveSummaryData,
RetrieveSummaryResponse
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
Resolve_username
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
import os
from mcp.server.fastmcp import Context
# 加载.env文件
load_dotenv()
logger = get_agent_logger(__name__)
@@ -40,8 +39,9 @@ async def Summary(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize the verified data.
@@ -52,6 +52,7 @@ async def Summary(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -59,12 +60,16 @@ async def Summary(
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -155,7 +160,7 @@ async def Summary(
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"验证之后的总结==>>:{aimessages}")
logger.info(f"Summary after verification: {aimessages}")
# Log execution time
end = time.time()
@@ -163,7 +168,7 @@ async def Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('总结', duration)
log_time('Summary', duration)
return {
"status": "success",
@@ -180,8 +185,9 @@ async def Retrieve_Summary(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize data directly from retrieval results.
@@ -192,6 +198,7 @@ async def Retrieve_Summary(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -202,9 +209,13 @@ async def Retrieve_Summary(
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
@@ -212,6 +223,8 @@ async def Retrieve_Summary(
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
@@ -252,17 +265,19 @@ async def Retrieve_Summary(
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Samll (typo) for backward compatibility
# Check for both Answer_Small and the legacy misspelled key Answer_Samll for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
elif "Answer_Samll" in item:
answer = item["Answer_Samll"]
if answer is not None:
# Handle both string and list formats
@@ -350,7 +365,7 @@ async def Retrieve_Summary(
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"检索之后的总结==>>:{aimessages}")
logger.info(f"Summary after retrieval: {aimessages}")
# Log execution time
end = time.time()
@@ -358,7 +373,7 @@ async def Retrieve_Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('检索总结', duration)
log_time('Retrieval summary', duration)
# Emit intermediate output for frontend
return {
@@ -384,8 +399,9 @@ async def Input_Summary(
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
user_rag_memory_id: str = "",
) -> dict:
"""
Generate a quick summary for direct input without verification.
@@ -397,6 +413,7 @@ async def Input_Summary(
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
@@ -406,21 +423,16 @@ async def Input_Summary(
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Initialize variables to avoid UnboundLocalError
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
session_service = get_context_resource(ctx, 'session_service')
llm_client = get_context_resource(ctx, 'llm_client')
search_service = get_context_resource(ctx, 'search_service')
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Check if llm_client is None
if llm_client is None:
error_msg = "LLM client is not available. Please check server configuration and SELECTED_LLM_ID environment variable."
logger.error(error_msg)
return error_msg
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
@@ -479,7 +491,7 @@ async def Input_Summary(
# Add storage-specific parameters
'''检索'''
# Retrieval
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
@@ -509,12 +521,16 @@ async def Input_Summary(
except:
retrieve_info=''
raw_results=['']
logger.info(f"知识库没有检索的内容{user_rag_memory_id}")
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
logger.info("Input_Summary: 使用 summary 进行检索")
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
logger.info("Input_Summary: Using summary for retrieval")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(**search_params)
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
except Exception as e:
logger.error(
@@ -547,7 +563,7 @@ async def Input_Summary(
)
aimessages = "信息不足,无法回答"
logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
logger.info(f"Quick answer summary: {storage_type}--{user_rag_memory_id}--{aimessages}")
# Emit intermediate output for frontend
return {
@@ -587,7 +603,7 @@ async def Input_Summary(
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
log_time('Retrieval', duration)
@mcp.tool()

View File

@@ -5,20 +5,19 @@ This module contains MCP tools for verifying retrieved data.
"""
import time
from jinja2 import Template
from mcp.server.fastmcp import Context
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.core.memory.agent.utils.messages_tool import (
Verify_messages_deal,
Retrieve_verify_tool_messages_deal,
Resolve_username
)
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Retrieve_verify_tool_messages_deal,
Verify_messages_deal,
)
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.schemas.memory_config_schema import MemoryConfig
from jinja2 import Template
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@@ -30,6 +29,7 @@ async def Verify(
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
@@ -42,6 +42,7 @@ async def Verify(
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
@@ -91,8 +92,12 @@ async def Verify(
# Call verification workflow
verify_tool = VerifyTool(system_prompt, messages)
# Call verification workflow with LLM model ID from memory_config
verify_tool = VerifyTool(
system_prompt=system_prompt,
verify_data=messages,
llm_model_id=str(memory_config.llm_model_id)
)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
@@ -118,7 +123,7 @@ async def Verify(
"history": history,
}
logger.info(f"验证==>>:{messages_deal}")
logger.info(f"Verification result: {messages_deal}")
# Emit intermediate output for frontend
return {
@@ -128,7 +133,7 @@ async def Verify(
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "数据验证",
"title": "Data Verification",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
@@ -166,4 +171,4 @@ async def Verify(
duration = end - start
except Exception:
duration = 0.0
log_time('验证', duration)
log_time('Verification', duration)

View File

@@ -1,22 +1,21 @@
import asyncio
import json
from collections import defaultdict
from typing import TypedDict, Annotated
import os
import logging
from jinja2 import Template
from langchain_core.messages import AnyMessage
from dotenv import load_dotenv
from langgraph.graph import add_messages
from openai import OpenAI
import os
from collections import defaultdict
from typing import Annotated, TypedDict
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import get_picture_config, get_voice_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
from app.core.models.base import RedBearModelConfig
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import (
get_picture_config,
get_voice_config,
)
# Removed global variable imports - use dependency injection instead
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from openai import OpenAI
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
@@ -44,6 +43,7 @@ class WriteState(TypedDict):
user_id:str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
class ReadState(TypedDict):
'''
@@ -53,6 +53,7 @@ class ReadState(TypedDict):
loop_count:Traverse times
search_switch: type
config_id: configuration id for filtering results
errors: list of errors that occurred during workflow execution
'''
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
name: str
@@ -63,6 +64,7 @@ class ReadState(TypedDict):
apply_id: str
group_id: str
config_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
class COUNTState:
@@ -109,9 +111,17 @@ def deduplicate_entries(entries):
async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
image_path: Path to image file
PROMPT_TICKET_EXTRACTION: Extraction prompt
picture_model_name: Picture model name (required, no longer from global variables)
"""
try:
model_config = get_picture_config(SELECTED_LLM_PICTURE_NAME)
model_config = get_picture_config(picture_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
@@ -147,9 +157,15 @@ async def Picture_recognize(image_path,PROMPT_TICKET_EXTRACTION) -> str:
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize():
async def Voice_recognize(voice_model_name: str):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
voice_model_name: Voice model name (required, no longer from global variables)
"""
try:
model_config = get_voice_config(SELECTED_LLM_VOICE_NAME)
model_config = get_voice_config(voice_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)

View File

@@ -1,10 +1,10 @@
import json
import logging
import re
from typing import List, Any
from typing import Any, List
from langchain_core.messages import AnyMessage
from app.core.logging_config import get_agent_logger
from langchain_core.messages import AnyMessage
logger = get_agent_logger(__name__)
@@ -119,11 +119,23 @@ async def Problem_Extension_messages_deal(context):
extent_quest = []
original = context.get('original', '')
messages = context.get('context', '')
messages = json.loads(messages)
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
# Handle empty or non-string messages
if not messages:
return extent_quest, original
if isinstance(messages, str):
try:
messages = json.loads(messages)
except json.JSONDecodeError:
# If JSON parsing fails, return empty list
return extent_quest, original
if isinstance(messages, list):
for message in messages:
question = message.get('question', '')
type = message.get('type', '')
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
return extent_quest, original
@@ -135,10 +147,19 @@ async def Retriev_messages_deal(context):
context:
Returns:
'''
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
if isinstance(context, dict):
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
if 'context' in context or 'original' in context:
return context.get('context', {}), context.get('original', '')
return content, original_value
content = context.get('context', {})
original = context.get('original', '')
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
return content, original
# Return empty defaults if context is not a dict or doesn't have expected keys
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
return {}, ''
async def Verify_messages_deal(context):
'''

View File

@@ -1,22 +1,22 @@
# 角色
你是验证专家
你的目标是针对用户的输入Query_Samll字段的提问和Answer_Samll的回答分析是不是回答Query_Samll这个字段的问题
你的目标是针对用户的输入Query_Small字段的提问和Answer_Small的回答分析是不是回答Query_Small这个字段的问题
{#以下可以采用先总括,再展开详细说明的方式,描述你希望智能体在每一个步骤如何进行工作,具体的工作步骤数量可以根据实际需求增删#}
## 工作步骤
1. 获取所有的Query_Samll字段和Answer_Samll字段
2. 分析Answer_Samll的回复是不是和Query_Samll有关系
3. 判断Answer_Samll和Query_Samll之间分析出来的关系状态
1. 获取所有的Query_Small字段和Answer_Small字段
2. 分析Answer_Small的回复是不是和Query_Small有关系
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
4. 如果是True保留否则不要相对应的问题和回答
5. 输出,需要严格按照模版
输入:{{history}}
历史消息:{"history":{{sentence}}}
### 第一步 获取用户的输入
获取用户的输入提取对应的Query_Samll和Answer_Samll
获取用户的输入提取对应的Query_Small和Answer_Small
### 第二步 分析验证
需要分析Query_Samll和Answer_Samll之间的关系可以参考history字段的内容如果有关系不是答非所问
需要分析Query_Small和Answer_Small之间的关系可以参考history字段的内容如果有关系不是答非所问
## 核心验证标准
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Samll
在评估子问题拆分时必须严格遵循以下标准且验证过程中完全不依赖于子问题的相关信息Answer_Small
1. 合理性标准(必须全部满足)
- 完整性:每个不同的子问题必须完整覆盖原问题的所有关键要素(如时间、主体、动作、目标等),无遗漏。
- 最小化每个不同的子问题数量应尽可能少通常不超过原问题关键要素数量的2倍建议2-4个避免冗余和不必要拆分。

View File

@@ -1,15 +1,14 @@
"""
Type classification utility for distinguishing read/write operations.
"""
from jinja2 import Template
from pydantic import BaseModel
from app.core.config import settings
from app.core.logging_config import get_agent_logger, log_prompt_rendering
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from jinja2 import Template
from pydantic import BaseModel
logger = get_agent_logger(__name__)
@@ -19,12 +18,14 @@ class DistinguishTypeResponse(BaseModel):
type: str
async def status_typle(messages: str) -> dict:
async def status_typle(messages: str, llm_model_id: str) -> dict:
"""
Classify message type as read or write operation.
Updated to eliminate global variables in favor of explicit parameters.
Args:
messages: User message to classify
llm_model_id: LLM model ID to use (required, no longer from global variables)
Returns:
dict: Contains 'type' field with classification result
@@ -42,8 +43,9 @@ async def status_typle(messages: str) -> dict:
"message": f"Prompt rendering failed: {str(e)}"
}
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_model_id)
try:
structured = await llm_client.response_structured(

View File

@@ -1,18 +1,19 @@
from typing import TypedDict, Annotated, List, Any
from langchain_core.messages import AnyMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph, add_messages
import asyncio
import json
from dotenv import load_dotenv, find_dotenv
import os
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from langchain_core.messages import HumanMessage
from jinja2 import Environment, FileSystemLoader
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
from typing import Annotated, Any, List, TypedDict
# Removed global variable imports - use dependency injection instead
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import _to_openai_messages
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from dotenv import find_dotenv, load_dotenv
from jinja2 import Environment, FileSystemLoader
from langchain_core.messages import AnyMessage, HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph, add_messages
load_dotenv(find_dotenv())
@@ -31,8 +32,17 @@ class State(TypedDict):
class VerifyTool:
def __init__(self, system_prompt: str="", verify_data: Any=None):
def __init__(self, system_prompt: str="", verify_data: Any=None, llm_model_id: str=None):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
system_prompt: System prompt for verification
verify_data: Data to verify
llm_model_id: LLM model ID (required, no longer from global variables)
"""
self.system_prompt = system_prompt
self.llm_model_id = llm_model_id
if isinstance(verify_data, str):
self.verify_data = verify_data
else:
@@ -42,7 +52,11 @@ class VerifyTool:
self.verify_data = str(verify_data)
async def model_1(self, state: State) -> State:
llm_client = get_llm_client(SELECTED_LLM_ID)
if not self.llm_model_id:
raise ValueError("llm_model_id is required but not provided")
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response_content = await llm_client.chat(
messages=[{"role": "system", "content": self.system_prompt}, *_to_openai_messages(state["messages"])]
)

View File

@@ -1,91 +1,93 @@
import asyncio
from dotenv import load_dotenv
"""
Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally.
"""
import time
from datetime import datetime
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation_all,
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import Memory_summary_generation
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
load_dotenv()
logger = get_agent_logger(__name__)
async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id: str = "wyl20251027", config_id: str = None) -> None:
async def write(
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
ref_id: str = "wyl20251027",
) -> None:
"""
执行完整的知识提取流水线(使用新的 ExtractionOrchestrator
Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args:
content: 对话内容
user_id: 用户ID
apply_id: 应用ID
group_id: 组ID
ref_id: 参考ID默认为 "wyl20251027"
config_id: 配置ID用于标记数据处理配置
content: Dialogue content to process
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
ref_id: Reference ID, defaults to "wyl20251027"
"""
# 如果提供了config_id重新加载配置
if config_id:
from app.core.memory.utils.config.definitions import reload_configuration_from_database
logger.info(f"Reloading configuration for config_id: {config_id}")
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
logger.error(error_msg)
raise ValueError(error_msg)
logger.info(f"Configuration reloaded successfully for config_id: {config_id}")
# Extract config values
embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy
config_id = str(memory_config.config_id)
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
logger.info(f"Using model: {config_defs.SELECTED_LLM_NAME}")
logger.info(f"Using LLM ID: {config_defs.SELECTED_LLM_ID}")
logger.info(f"Using chunker strategy: {config_defs.SELECTED_CHUNKER_STRATEGY}")
logger.info(f"Using group ID: {config_defs.SELECTED_GROUP_ID}")
logger.info(f"Using embedding ID: {config_defs.SELECTED_EMBEDDING_ID}")
logger.info(f"Config ID: {config_id if config_id else 'None'}")
logger.info(f"LANGFUSE_ENABLED: {config_defs.LANGFUSE_ENABLED}")
logger.info(f"AGENTA_ENABLED: {config_defs.AGENTA_ENABLED}")
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
logger.info(f"Workspace: {memory_config.workspace_name}")
logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}")
# Construct clients from memory_config using factory pattern with db session
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
embedder_client = factory.get_embedder_client_from_config(memory_config)
logger.info("LLM and embedding clients constructed")
# Initialize timing log
log_file = "logs/time.log"
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"\n=== Pipeline Run Started: {timestamp} ===\n")
f.write(f"Config: {memory_config.config_name} (ID: {config_id})\n")
pipeline_start = time.time()
# 初始化客户端
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 获取 embedder 配置
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
# Initialize Neo4j connector
neo4j_connector = Neo4jConnector()
# Step 1: 加载和分块数据
# Step 1: Load and chunk data
step_start = time.time()
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
chunker_strategy=chunker_strategy,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
@@ -94,21 +96,21 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
config_id=config_id,
)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
# Step 2: 初始化并运行 ExtractionOrchestrator
# Step 2: Initialize and run ExtractionOrchestrator
step_start = time.time()
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
pipeline_config = get_pipeline_config(memory_config)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
config=pipeline_config,
embedding_id=embedding_model_id,
)
# 运行完整的提取流水线
# orchestrator.run returns a flat tuple of 7 values after deduplication
# Run the complete extraction pipeline
(
all_dialogue_nodes,
all_chunk_nodes,
@@ -118,14 +120,12 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
all_statement_entity_edges,
all_entity_entity_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# Step 8: Save all data to Neo4j database using graph models
# Step 3: Save all data to Neo4j database
step_start = time.time()
# 运行索引创建
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
@@ -152,18 +152,16 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
log_time("Neo4j Database Save", time.time() - step_start, log_file)
# Step 9: Generate Memory summaries and save to local vector DB and Neo4j
# Step 4: Generate Memory summaries and save to Neo4j
step_start = time.time()
try:
summaries = await Memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
)
# Save memory summaries to Neo4j as nodes
try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector)
# Link summaries to statements via chunks for summary→entity queries
await add_memory_summary_statement_edges(summaries, ms_connector)
finally:
try:
@@ -173,24 +171,15 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
except Exception as e:
logger.error(f"Memory summary step failed: {e}", exc_info=True)
finally:
log_time("Memory Summary (Local Vector DB & Neo4j)", time.time() - step_start, log_file)
log_time("Memory Summary (Neo4j)", time.time() - step_start, log_file)
# Log total pipeline time
total_time = time.time() - pipeline_start
log_time("TOTAL PIPELINE TIME", total_time, log_file)
# Add completion marker to log
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds")
logger.info(f"Timing details saved to: {log_file}")
if __name__ == "__main__":
content = "你好,我是张三,是张曼婷的新朋友。请问张曼婷喜欢什么?"
asyncio.run(write(content, ref_id="wyl20251027"))

View File

@@ -1,8 +1,9 @@
import sys
import os
import asyncio
from neo4j import GraphDatabase
import os
import sys
from typing import List, Tuple
from neo4j import GraphDatabase
from pydantic import BaseModel, Field
# ------------------- 自包含路径解析 -------------------
@@ -31,21 +32,54 @@ except NameError:
# ---------------------------------------------------------------------
# 现在路径已经配置好,我们可以使用绝对导入
from app.core.config import settings
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
import json
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
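# TODO: Fix this — presumably means dropping these env-var fallbacks once every user has a proper memory config (cf. the DEFAULT_LLM_ID TODOs below); confirm with the author.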
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], llm_client) -> List[str]:
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
"""
try:
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
# 3. 构建Prompt
tag_list_str = ", ".join(tags)
@@ -140,8 +174,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认False按group_id查询
"""
# 默认从 runtime.json selections.group_id 读取
group_id = group_id or SELECTED_GROUP_ID
# 默认从环境变量读取
group_id = group_id or DEFAULT_GROUP_ID
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
if not raw_tags_with_freq:
@@ -150,9 +184,7 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, llm_client)
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = []
@@ -165,8 +197,8 @@ async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_u
if __name__ == "__main__":
print("开始获取热门记忆标签...")
try:
# 直接使用 runtime.json 中的 group_id
group_id_to_query = SELECTED_GROUP_ID
# 直接使用环境变量中的 group_id
group_id_to_query = DEFAULT_GROUP_ID
# 使用 asyncio.run 来执行异步主函数
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))

View File

@@ -5,9 +5,9 @@ This script can be executed directly to generate a memory insight report for a t
"""
import asyncio
import json
import os
import sys
import json
from collections import Counter
from datetime import datetime
@@ -17,12 +17,18 @@ src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if src_path not in sys.path:
sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
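# TODO: Fix this — presumably means dropping these env-var fallbacks once every user has a proper memory config (cf. the DEFAULT_LLM_ID TODOs below); confirm with the author.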
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class TagClassification(BaseModel):
@@ -55,8 +61,33 @@ class MemoryInsight:
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
"""关闭数据库连接。"""
@@ -294,8 +325,8 @@ async def main():
"""
Initializes and runs the memory insight analysis for a test user.
"""
# 默认从 runtime.json selections.group_id 读取
test_user_id = SELECTED_GROUP_ID
# 默认从环境变量读取
test_user_id = DEFAULT_GROUP_ID
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
insight = None

View File

@@ -6,10 +6,10 @@ Usage:
python -m analytics.user_summary --user_id <group_id>
"""
import os
import sys
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple
@@ -24,10 +24,17 @@ try:
except Exception:
pass
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
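# TODO: Fix this — presumably means dropping these env-var fallbacks once every user has a proper memory config (cf. the DEFAULT_LLM_ID TODOs below); confirm with the author.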
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
@dataclass
@@ -42,8 +49,33 @@ class UserSummary:
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
from app.core.memory.utils.config import definitions as config_defs
self.llm = get_llm_client(config_defs.SELECTED_LLM_ID)
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()
@@ -107,8 +139,8 @@ class UserSummary:
async def generate_user_summary(user_id: str | None = None) -> str:
# 默认从 runtime.json selections.group_id 读取
effective_group_id = user_id or SELECTED_GROUP_ID
# 默认从环境变量读取
effective_group_id = user_id or DEFAULT_GROUP_ID
svc = UserSummary(effective_group_id)
try:
return await svc.generate()
@@ -139,7 +171,7 @@ if __name__ == "__main__":
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["user_summary"] = {
"group_id": SELECTED_GROUP_ID,
"group_id": DEFAULT_GROUP_ID,
"summary": summary
}
with open(dashboard_path, "w", encoding="utf-8") as wf:

View File

@@ -1,132 +0,0 @@
{
"llm_list": [
{
"llm_name": "qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_AGENT_KEY"
},
{
"llm_name": "openai/qwen2.5-14b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen2.5-14b-instruct-awq",
"api_base": "http://175.27.131.196:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen3-14b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/deepseek-r1-0528-qwen3-8b",
"api_base": "http://43.137.4.24:9090/v1",
"api_key": "OPENAI_API_KEY"
},
{
"llm_name": "openai/qwen3-235b-a22b-instruct-2507",
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "DASHSCOPE_API_KEY"
}
,
{
"llm_name": "openai/qwen-plus",
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"api_key": "DASHSCOPE_API_KEY"
},
{
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0"
},
{
"llm_name": "bedrock/global.anthropic.claude-sonnet-4-20250514-v1:0"
}
],
"embedding_list": [
{
"embedding_name": "openai/nomic-embed-text:v1.5",
"api_base": "http://119.45.239.97:11434/v1",
"dimension": 768
},
{
"embedding_name": "openai/bge-m3",
"api_base": "http://43.137.4.24:9090/v1",
"dimension": 1024
}
],
"neo4j": {
"uri": "bolt://1.94.111.67:7687",
"username": "neo4j"
},
"chunker_list": [
{
"chunker_strategy": "TokenChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"chunk_overlap": 56,
"tokenizer_or_token_counter": "character"
},
{
"chunker_strategy": "RecursiveChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"min_characters_per_chunk": 50
},
{
"chunker_strategy": "SemanticChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 1024,
"threshold": 0.8,
"min_sentences": 2,
"skip_window": 1,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "LateChunker",
"embedding_model": "all-MiniLM-L6-v2",
"chunk_size": 2048,
"min_characters_per_chunk": 24
},
{
"chunker_strategy": "NeuralChunker",
"embedding_model": "mirth/chonky_modernbert_base_1",
"min_characters_per_chunk": 24
},
{
"chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 1000,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "HybridChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"threshold": 0.8,
"min_characters_per_chunk": 100
},
{
"chunker_strategy": "SentenceChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 2048,
"chunk_overlap": 128,
"min_sentences_per_chunk": 1,
"min_characters_per_sentence": 12,
"delim": [".", "!", "?", "\n"],
"include_delim": "prev",
"tokenizer_or_token_counter": "character"
}
],
"langfuse": {
"enabled": true
},
"agenta": {
"enabled": false
}
}

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +0,0 @@
{
"selections": {
"config_id": ""
}
}

View File

@@ -1,22 +1,34 @@
import os
import asyncio
import json
from typing import List, Dict, Any, Optional
from datetime import datetime
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import SELECTED_CHUNKER_STRATEGY, SELECTED_EMBEDDING_ID
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
@@ -52,7 +64,9 @@ async def ingest_contexts_via_full_pipeline(
llm_available = True
try:
from app.core.memory.utils.config import definitions as config_defs
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False
@@ -133,12 +147,13 @@ async def ingest_contexts_via_full_pipeline(
return False
# 初始化 embedder 客户端
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try:
embedder_config_dict = get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e:
@@ -236,15 +251,15 @@ async def ingest_contexts_via_full_pipeline(
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
Memory_summary_generation,
memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await Memory_summary_generation(
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedding_id=embedding_name or SELECTED_EMBEDDING_ID
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:

View File

@@ -15,7 +15,7 @@ import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
try:
from dotenv import load_dotenv
@@ -23,37 +23,38 @@ except ImportError:
def load_dotenv():
pass
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
SELECTED_EMBEDDING_ID
)
from app.core.memory.utils.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import (
f1_score,
avg_context_tokens,
bleu1,
f1_score,
jaccard,
latency_stats,
avg_context_tokens
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score,
locomo_multi_f1,
get_category_name
)
from app.core.memory.evaluation.locomo.locomo_utils import (
load_locomo_data,
extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references,
select_and_format_information,
retrieve_relevant_information,
ingest_conversations_if_needed
select_and_format_information,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
@@ -160,10 +161,16 @@ async def run_locomo_benchmark(
# Step 3: Initialize clients
print("🔧 Initializing clients...")
connector = Neo4jConnector()
llm_client = get_llm_client(SELECTED_LLM_ID)
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -1,14 +1,16 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import os
import sys
import json
import time
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
from dotenv import load_dotenv
# 1
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -34,7 +36,7 @@ def _loc_normalize(text: str) -> str:
# 尝试从 metrics.py 导入基础指标
try:
from common.metrics import f1_score, bleu1, jaccard
from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
@@ -111,10 +113,14 @@ try:
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
@@ -429,13 +435,17 @@ async def run_enhanced_evaluation():
return None
# 修正导入路径:使用 app.core.memory.src 前缀
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据
# 获取项目根目录
@@ -458,10 +468,14 @@ async def run_enhanced_evaluation():
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
llm = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -2,10 +2,11 @@ import argparse
import asyncio
import json
import os
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
import statistics
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
@@ -13,16 +14,31 @@ except Exception:
return None
import re
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
@@ -327,9 +343,13 @@ async def run_locomo_eval(
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
# 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio
import json
import os
import time
import re
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -16,6 +16,7 @@ except Exception:
# 确保可以找到 src 及项目根路径
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR)))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -25,19 +26,33 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
# 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
try:
# 优先从 extraction_utils1 导入
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline # type: ignore
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline, # type: ignore
)
except Exception:
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import exact_match
except Exception:
@@ -686,9 +701,13 @@ async def run_longmemeval_test(
)
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
@@ -748,10 +767,10 @@ async def run_longmemeval_test(
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# for sm in summaries:
# summary_text = str(sm.get("summary", "")).strip()
# if summary_text:
# contexts_all.append(summary_text)
# 实体摘要最多3个
scored = [e for e in entities if e.get("score") is not None]

View File

@@ -2,11 +2,11 @@ import argparse
import asyncio
import json
import os
import time
import re
import statistics
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -15,15 +15,26 @@ except Exception:
return None
# 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.core.memory.utils.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_LLM_ID, SELECTED_EMBEDDING_ID
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import exact_match
except Exception:
@@ -647,9 +658,13 @@ async def run_longmemeval_test(
items = qa_list[start_index:start_index + sample_size]
# 初始化组件 - 使用异步LLM客户端
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector()
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -4,19 +4,35 @@ import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
@@ -119,7 +135,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid") -> Dict[str, Any]:
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
@@ -134,7 +150,9 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
await ingest_contexts_via_full_pipeline(contexts, group_id)
# LLM client (使用异步调用)
llm_client = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item
connector = Neo4jConnector()
@@ -159,6 +177,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
memory_config=memory_config,
)
except Exception:
results = None
@@ -242,7 +261,11 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))

View File

@@ -2,10 +2,10 @@ import argparse
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import List, Dict, Any
import re
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
@@ -15,6 +15,7 @@ except Exception:
# 路径与模块导入保持与现有评估脚本一致
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
@@ -23,17 +24,27 @@ for _p in (_SRC_DIR, _PROJECT_ROOT):
sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config_utils import get_embedder_config
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.definitions import PROJECT_ROOT, SELECTED_GROUP_ID, SELECTED_EMBEDDING_ID, SELECTED_LLM_ID
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
try:
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
@@ -226,13 +237,17 @@ async def run_memsciqa_test(
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
llm = get_llm_client(SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
cfg_dict = get_embedder_config(SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)

View File

@@ -5,18 +5,17 @@ OpenAI LLM 客户端实现
"""
import asyncio
from typing import List, Dict, Any
import json
import logging
from typing import Any, Dict, List
from pydantic import BaseModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from app.core.config import settings
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from app.core.models.base import RedBearModelConfig
from app.core.models.llm import RedBearLLM
from app.core.memory.llm_tools.llm_client import LLMClient, LLMClientException
from app.core.memory.utils.config.definitions import LANGFUSE_ENABLED
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel
logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@ class OpenAIClient(LLMClient):
# 初始化 Langfuse 回调处理器(如果启用)
self.langfuse_handler = None
if LANGFUSE_ENABLED:
if settings.LANGFUSE_ENABLED:
try:
from langfuse.langchain import CallbackHandler
self.langfuse_handler = CallbackHandler()

View File

@@ -1,403 +0,0 @@
"""
MemSci 记忆系统主入口 - 重构版本
该模块是重构后的记忆系统主入口,使用新的模块化架构。
旧版本入口app/core/memory/src/main.py已删除。
主要功能:
1. 协调整个知识提取流水线
2. 支持试运行模式和正常运行模式
3. 使用重构后的 storage_services 模块
4. 提供统一的配置管理和日志记录
作者Lance77
日期2025-11-22
"""
# 必须在最开始禁用 LangSmith 追踪,避免速率限制错误
import os
os.environ["LANGCHAIN_TRACING_V2"] = "false"
os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional, Callable, Awaitable
from dotenv import load_dotenv
# 导入重构后的模块
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import ConversationMessage, ConversationContext, DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
# 导入数据加载函数
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
get_chunked_dialogs_with_preprocessing,
get_chunked_dialogs_from_preprocessed,
)
# 导入配置模块(而不是直接导入变量)
from app.core.memory.utils.config import definitions as config_defs
from app.core.logging_config import get_memory_logger, log_time
load_dotenv()
logger = get_memory_logger(__name__)
async def _report_chunking_progress(
    progress_callback: Optional[Callable[..., Awaitable[None]]],
    chunked_dialogs,
    chunker_strategy,
) -> None:
    """Report per-chunk previews and a chunking summary via *progress_callback*.

    Emits one ``text_preprocessing_result`` event per chunk (content preview
    truncated to 200 characters) followed by a single
    ``text_preprocessing_complete`` event. Does nothing when no callback is set.
    """
    if not progress_callback:
        return

    for dialog in chunked_dialogs:
        for i, chunk in enumerate(dialog.chunks):
            chunk_result = {
                "chunk_index": i + 1,
                "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                "full_length": len(chunk.content),
                "dialog_id": dialog.id,
                "chunker_strategy": chunker_strategy,
            }
            await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

    preprocessing_summary = {
        "total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
        "total_dialogs": len(chunked_dialogs),
        "chunker_strategy": chunker_strategy,
    }
    await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)


async def main(
    dialogue_text: Optional[str] = None,
    is_pilot_run: bool = False,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
    """Run the memory-system knowledge extraction pipeline end to end.

    Args:
        dialogue_text: Optional dialogue text. When it is a non-blank string it
            is parsed into messages and processed; otherwise the bundled
            ``data/testdata.json`` next to this module is loaded.
        is_pilot_run: When True, nothing is written to Neo4j; when False the
            extracted nodes/edges and the memory summaries are saved.
        progress_callback: Optional async callable ``(stage, message, data)``
            invoked at key points of the pipeline. ``data`` is only passed for
            stages that carry a payload.

    Raises:
        FileNotFoundError: If no dialogue text is given and the test data file
            does not exist.
        Exception: Any failure in the pipeline is logged and re-raised.

    Workflow:
        1. Initialise the LLM, embedder and Neo4j clients.
        2. Load and chunk the input data.
        3. Run the extraction orchestrator.
        4. Save the results (normal mode) or skip saving (pilot run).
        5. Generate memory summaries (best effort; failures are only logged).
    """
    print("=" * 60)
    print("MemSci 知识提取流水线 - 重构版本")
    print("=" * 60)
    print(f"运行模式: {'试运行不保存到Neo4j' if is_pilot_run else '正常运行保存到Neo4j'}")
    print("Using chunker strategy:", config_defs.SELECTED_CHUNKER_STRATEGY)
    print("Using group ID:", config_defs.SELECTED_GROUP_ID)
    print("Using model ID:", config_defs.SELECTED_LLM_ID)
    print("Using embedding model ID:", config_defs.SELECTED_EMBEDDING_ID)
    print("LANGFUSE_ENABLED:", config_defs.LANGFUSE_ENABLED)
    print("AGENTA_ENABLED:", config_defs.AGENTA_ENABLED)
    print("=" * 60)

    # Prepare the timing log and mark the start of this run.
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pipeline Run Started: {timestamp} ({'Pilot Run' if is_pilot_run else 'Normal Run'}) ===\n")

    pipeline_start = time.time()

    # Bound before the try block so the cleanup in ``finally`` never touches an
    # undefined name when client initialisation itself fails.
    neo4j_connector = None

    try:
        # Step 1: initialise clients.
        logger.info("Initializing clients...")
        step_start = time.time()

        llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)

        # Fetch the embedder configuration and turn it into a RedBearModelConfig.
        from app.core.models.base import RedBearModelConfig
        embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
        embedder_config = RedBearModelConfig(**embedder_config_dict)
        embedder_client = OpenAIEmbedderClient(embedder_config)

        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: load or prepare the data.
        logger.info("Loading data...")
        logger.info(f"[MAIN] dialogue_text type={type(dialogue_text)}, length={len(dialogue_text) if dialogue_text else 0}, is_pilot_run={is_pilot_run}")
        logger.info(f"[MAIN] dialogue_text preview: {repr(dialogue_text)[:200] if dialogue_text else 'None'}")
        logger.info(f"[MAIN] Condition check: dialogue_text={bool(dialogue_text)}, isinstance={isinstance(dialogue_text, str) if dialogue_text else False}, strip={bool(dialogue_text.strip()) if dialogue_text and isinstance(dialogue_text, str) else False}")
        step_start = time.time()

        if dialogue_text and isinstance(dialogue_text, str) and dialogue_text.strip():
            # Dialogue text was supplied by the caller (front end).
            logger.info("[MAIN] ✓ Using frontend dialogue text (pilot run mode)")
            import re

            # Split the text into "<speaker>: <content>" turns; a turn runs until
            # the next line that starts with a speaker label.
            # NOTE(review): the character class below contains only the ASCII
            # colon as it appears in this view; a full-width colon may have been
            # lost in extraction - confirm against the repository.
            pattern = r"(用户|AI)[:]\s*([^\n]+(?:\n(?!(?:用户|AI)[:])[^\n]*)*?)"
            matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
            messages = [
                ConversationMessage(role=r, msg=c.strip())
                for r, c in matches if c.strip()
            ]

            # No formatted turns found: treat the whole text as one user message.
            if not messages:
                messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]

            # Build the conversation context and the dialogue record.
            context = ConversationContext(msgs=messages)
            dialog = DialogData(
                context=context,
                ref_id="pilot_dialog_1",
                group_id=config_defs.SELECTED_GROUP_ID,
                user_id=config_defs.SELECTED_USER_ID,
                apply_id=config_defs.SELECTED_APPLY_ID,
                metadata={"source": "pilot_run", "input_type": "frontend_text"}
            )

            # Progress: text preprocessing starts.
            if progress_callback:
                await progress_callback("text_preprocessing", "开始预处理文本...")

            # Chunk the supplied dialogue.
            chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
                data=[dialog],
                chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
                llm_client=llm_client,
            )
            logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
        else:
            # No usable dialogue text: fall back to the bundled test data file.
            logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
            logger.info("Loading data from testdata.json...")
            test_data_path = os.path.join(
                os.path.dirname(__file__), "data", "testdata.json"
            )

            if not os.path.exists(test_data_path):
                raise FileNotFoundError(f"Test data file not found: {test_data_path}")

            # Progress: text preprocessing starts.
            if progress_callback:
                await progress_callback("text_preprocessing", "开始预处理文本...")

            chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
                chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
                group_id=config_defs.SELECTED_GROUP_ID,
                user_id=config_defs.SELECTED_USER_ID,
                apply_id=config_defs.SELECTED_APPLY_ID,
                indices=config_defs.SELECTED_TEST_DATA_INDICES,
                input_data_path=test_data_path,
                llm_client=llm_client,
                skip_cleaning=True,
            )
            logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")

        # Progress: per-chunk results and the preprocessing summary. Both input
        # branches report identically, so this is done once here.
        await _report_chunking_progress(
            progress_callback, chunked_dialogs, config_defs.SELECTED_CHUNKER_STRATEGY
        )

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialise the pipeline orchestrator.
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        # Load the pipeline configuration.
        from app.core.memory.utils.config.config_utils import get_pipeline_config
        config = get_pipeline_config()

        logger.info(f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}")

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,  # forward progress reporting
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the knowledge extraction pipeline.
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        # Progress: knowledge extraction in progress.
        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...")

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=is_pilot_run,  # forward the pilot-run flag
        )

        # The result is unpacked as a 7-tuple of node and edge collections.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        # Progress: generating results.
        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        # Step 5: save the results, or skip saving in pilot-run mode.
        if is_pilot_run:
            logger.info("Pilot run mode: Skipping Neo4j save")
            print("\n试运行模式：跳过 Neo4j 保存，流水线处理完成。")
            print("提取结果已生成，可在相关输出中查看。")
        else:
            logger.info("Normal mode: Saving to Neo4j...")
            step_start = time.time()

            # Create indexes and constraints; a failure here is logged but does
            # not stop the save below.
            try:
                from app.repositories.neo4j.create_indexes import (
                    create_fulltext_indexes,
                    create_unique_constraints,
                )
                await create_fulltext_indexes()
                await create_unique_constraints()
                logger.info("Successfully created indexes and constraints")
            except Exception as e:
                logger.error(f"Error creating indexes/constraints: {e}")

            # Save the extracted graph to Neo4j; failures are reported, not raised.
            try:
                from app.repositories.neo4j.graph_saver import (
                    save_dialog_and_statements_to_neo4j,
                )
                success = await save_dialog_and_statements_to_neo4j(
                    dialogue_nodes=dialogue_nodes,
                    chunk_nodes=chunk_nodes,
                    statement_nodes=statement_nodes,
                    entity_nodes=entity_nodes,
                    statement_chunk_edges=statement_chunk_edges,
                    statement_entity_edges=statement_entity_edges,
                    entity_edges=entity_edges,
                    connector=neo4j_connector,
                )
                if success:
                    logger.info("Successfully saved all data to Neo4j")
                    print("\n✓ 成功保存所有数据到 Neo4j")
                else:
                    logger.warning("Failed to save some data to Neo4j")
                    print("\n⚠ 部分数据保存到 Neo4j 失败")
            except Exception as e:
                logger.error(f"Error saving to Neo4j: {e}", exc_info=True)
                print(f"\n✗ 保存到 Neo4j 失败: {e}")

            log_time("Neo4j Database Save", time.time() - step_start, log_file)

        # Step 6: generate memory summaries (best effort).
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                Memory_summary_generation,
            )
            from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
            from app.repositories.neo4j.add_edges import (
                add_memory_summary_statement_edges,
            )

            summaries = await Memory_summary_generation(
                chunked_dialogs, llm_client=llm_client, embedding_id=config_defs.SELECTED_EMBEDDING_ID
            )

            if not is_pilot_run:
                # Persist the summaries through a dedicated connector that is
                # always closed afterwards.
                ms_connector = Neo4jConnector()
                try:
                    await add_memory_summary_nodes(summaries, ms_connector)
                    await add_memory_summary_statement_edges(summaries, ms_connector)
                finally:
                    await ms_connector.close()

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

    except Exception as e:
        logger.error(f"Pipeline execution failed: {e}", exc_info=True)
        print(f"\n✗ 流水线执行失败: {e}")
        raise

    finally:
        # Release the main connector if it was created. Closing stays best
        # effort, but a failure is logged instead of being silently dropped.
        if neo4j_connector is not None:
            try:
                await neo4j_connector.close()
            except Exception as e:
                logger.warning(f"Failed to close Neo4j connector: {e}")

    # NOTE(review): the original indentation is not visible in this view; the
    # summary below is assumed to run only after a successful run - confirm.

    # Record the total run time.
    total_time = time.time() - pipeline_start
    log_time("TOTAL PIPELINE TIME", total_time, log_file)

    # Append the completion marker to the timing log.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")

    logger.info("=== Pipeline Complete ===")
    logger.info(f"Total execution time: {total_time:.2f} seconds")
    logger.info(f"Timing details saved to: {log_file}")

    print("\n" + "=" * 60)
    print("✓ 流水线执行完成")
    print(f"✓ 总耗时: {total_time:.2f}")
    print(f"✓ 详细日志: {log_file}")
    print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -20,14 +20,14 @@ Classes:
MemorySummaryNode: Node representing a memory summary
"""
from uuid import uuid4
import re
from datetime import datetime, timezone
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
import re
from uuid import uuid4
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.utils.alias_utils import validate_aliases
from app.core.memory.utils.data.ontology import TemporalInfo
from pydantic import BaseModel, Field, field_validator
def parse_historical_datetime(v):
@@ -361,7 +361,7 @@ class ExtractedEntityNode(Node):
description="Entity aliases - alternative names for this entity"
)
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
fact_summary: str = Field(..., description="Summary of the fact about this entity")
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")

View File

@@ -10,9 +10,10 @@ Classes:
"""
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field
class Entity(BaseModel):
"""Represents an extracted entity from dialogue.

View File

@@ -1,31 +1,43 @@
import argparse
import asyncio
import json
import math
import os
import time
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
from datetime import datetime
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.logging_config import get_memory_logger
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import (
search_graph_by_embedding, search_graph,
search_graph_by_temporal, search_graph_by_keyword_temporal,
search_graph_by_chunk_id
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.config_models import TemporalSearchParams
from app.core.memory.utils.config.config_utils import get_embedder_config, get_pipeline_config
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import (
ForgettingEngine,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.data.text_utils import extract_plain_query
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.data.time_utils import normalize_date_safe
from app.core.memory.utils.llm.llm_utils import get_reranker_client
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_chunk_id,
search_graph_by_embedding,
search_graph_by_keyword_temporal,
search_graph_by_temporal,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from dotenv import load_dotenv
load_dotenv()
logger = get_memory_logger(__name__)
@@ -131,7 +143,7 @@ def rerank_hybrid_results(
# Add keyword results with BM25 scores
for item in keyword_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
@@ -139,7 +151,7 @@ def rerank_hybrid_results(
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
if item_id in combined_items:
# Update existing item with embedding score
@@ -220,7 +232,7 @@ def rerank_with_forgetting_curve(
(keyword_items, False), (embedding_items, True)
):
for item in src_items:
item_id = item.get("id") or item.get("uuid")
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
existing = combined_items.get(item_id)
@@ -266,26 +278,25 @@ def rerank_with_forgetting_curve(
return reranked
def log_search_query(
    query_text: str,
    search_type: str,
    group_id: str | None,
    limit: int,
    include: List[str],
    log_file: str | None = None,
):
    """Log the parameters of a search request through the module logger.

    The query is passed through ``extract_plain_query`` first so that only the
    cleaned text is logged. Nothing is written to disk any more: the previous
    implementation appended a JSON line to ``log_file``, which fails with a
    ``TypeError`` once ``log_file`` defaults to ``None``.

    Args:
        query_text: The search query text as received from the caller.
        search_type: Type of search (keyword, embedding, hybrid).
        group_id: Group identifier used for filtering, or ``None``.
        limit: Maximum number of results requested.
        include: List of result types to include.
        log_file: Deprecated and ignored; accepted only so that existing
            callers that still pass it keep working.
    """
    # Ensure the query text is plain and clean before logging.
    cleaned_query = extract_plain_query(query_text)

    logger.info(
        f"Search query: query='{cleaned_query}', type={search_type}, "
        f"group_id={group_id}, limit={limit}, include={include}"
    )
def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
@@ -315,229 +326,229 @@ def apply_reranker_placeholder(
If config enables reranker, annotate items with a final_score equal to combined_score
and keep ordering. This is a no-op reranker to be replaced later.
"""
try:
rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
except Exception as e:
logger.debug(f"Failed to load reranker config: {e}")
rc = {}
if not rc or not rc.get("enabled", False):
return results
# try:
# rc = (RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {}))
# except Exception as e:
# logger.debug(f"Failed to load reranker config: {e}")
# rc = {}
# if not rc or not rc.get("enabled", False):
# return results
top_k = int(rc.get("top_k", 100))
model_name = rc.get("model", "placeholder")
# top_k = int(rc.get("top_k", 100))
# model_name = rc.get("model", "placeholder")
for cat, items in results.items():
head = items[:top_k]
for it in head:
base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
it["final_score"] = base
it["reranker_model"] = model_name
# Keep overall order by final_score if present, otherwise combined/score
results[cat] = sorted(
items,
key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
reverse=True,
)
# for cat, items in results.items():
# head = items[:top_k]
# for it in head:
# base = float(it.get("combined_score", it.get("score", 0.0)) or 0.0)
# it["final_score"] = base
# it["reranker_model"] = model_name
# # Keep overall order by final_score if present, otherwise combined/score
# results[cat] = sorted(
# items,
# key=lambda x: float(x.get("final_score", x.get("combined_score", x.get("score", 0.0)) or 0.0)),
# reverse=True,
# )
return results
async def apply_llm_reranker(
    results: Dict[str, List[Dict[str, Any]]],
    query_text: str,
    reranker_client: Optional[Any] = None,
    llm_weight: Optional[float] = None,
    top_k: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """Apply LLM-based reranking to search results.

    For each of the categories ``statements``, ``chunks``, ``entities`` and
    ``summaries`` the ``top_k`` items (by ``combined_score``/``score``) are sent
    to the reranker client for a 0.0-1.0 relevance rating, and
    ``final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score``
    is stored on a copy of the item. Items that cannot be scored (no text,
    unparsable answer, client error, failed batch) and items beyond ``top_k``
    keep their combined score as ``final_score``; no item is dropped.

    Args:
        results: Search results organized by category.
        query_text: Original search query.
        reranker_client: Optional pre-initialized reranker client exposing an
            awaitable ``chat(messages)``; created with ``get_reranker_client()``
            when omitted.
        llm_weight: Weight for the LLM score (0.0-1.0, higher favors the LLM).
            Defaults to the ``llm_weight`` config value, else 0.5.
        top_k: Maximum number of items to rerank per category. Defaults to the
            ``top_k`` config value, else 20.
        batch_size: Number of items per batch. Items inside a batch are scored
            one after another; the batches themselves run concurrently.
            Defaults to the ``batch_size`` config value, else 5.

    Returns:
        A dict with exactly the four categories above, each sorted by
        ``final_score`` descending and annotated with ``reranker_model``.
        ``results`` is returned unchanged when reranking is disabled in the
        configuration or the reranker client cannot be initialized.
    """
    # Local import: ``re`` is not among this module's top-level imports.
    import re

    # Load reranker configuration (runtime config takes precedence).
    try:
        rc = RUNTIME_CONFIG.get("reranker", {}) or CONFIG.get("reranker", {})
    except Exception as e:
        logger.debug(f"Failed to load reranker config: {e}")
        rc = {}
    rc = rc or {}  # a null "reranker" entry must not break the lookups below

    if not rc.get("enabled", False):
        logger.debug("LLM reranking is disabled in configuration")
        return results

    # Explicit arguments win over configuration, configuration over defaults.
    llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
    top_k = top_k if top_k is not None else rc.get("top_k", 20)
    batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
    # Guard the slicing/range arithmetic: a zero or negative batch size would
    # make range() raise, a negative top_k would slice from the wrong end.
    top_k = max(0, int(top_k))
    batch_size = max(1, int(batch_size))

    if reranker_client is None:
        try:
            reranker_client = get_reranker_client()
        except Exception as e:
            logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
            return results

    # Model name is only used as metadata on the scored items.
    model_name = getattr(reranker_client, 'model_name', 'unknown')

    # Accepts "1", "0.85" and ".5"; the first number in the answer is used.
    score_pattern = re.compile(r"\d*\.?\d+")

    def base_score(item: Dict[str, Any]) -> float:
        """Return the pre-LLM score of an item (combined_score, else score, else 0)."""
        return float(item.get("combined_score", item.get("score", 0.0)) or 0.0)

    def fallback_copy(item: Dict[str, Any], llm_relevance: Optional[float] = None) -> Dict[str, Any]:
        """Copy an item and use its pre-LLM score as final_score."""
        scored = item.copy()
        scored["final_score"] = base_score(item)
        if llm_relevance is not None:
            scored["llm_relevance_score"] = llm_relevance
        scored["reranker_model"] = model_name
        return scored

    def extract_text(item: Dict[str, Any]) -> str:
        """Extract the text content of a result item, whatever its category."""
        text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
        return str(text).strip()

    async def score_item(item: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the LLM for a relevance score and return the annotated copy."""
        item_text = extract_text(item)
        if not item_text:
            # Nothing to show the LLM: keep the combined score, relevance 0.
            return fallback_copy(item, 0.0)

        prompt = (
            "Given the search query and a result item, rate the relevance of the item "
            "to the query on a scale from 0.0 to 1.0.\n\n"
            f"Query: {query_text}\n\n"
            f"Result: {item_text}\n\n"
            "Respond with only a number between 0.0 and 1.0, where:\n"
            "- 0.0 means completely irrelevant\n"
            "- 1.0 means perfectly relevant\n\n"
            "Relevance score:"
        )

        try:
            response = await reranker_client.chat([{"role": "user", "content": prompt}])
            response_text = str(response.content if hasattr(response, 'content') else response).strip()

            combined_score = base_score(item)
            match = score_pattern.search(response_text)
            if match is None:
                logger.warning(
                    f"Invalid LLM score format: {response_text}, using combined_score. "
                    "Error: No numeric score found in response"
                )
                return fallback_copy(item, combined_score)

            # Clamp to [0.0, 1.0] in case the model answers out of range.
            llm_score = max(0.0, min(1.0, float(match.group())))
            scored = item.copy()
            scored["llm_relevance_score"] = llm_score
            scored["final_score"] = (1 - llm_weight) * combined_score + llm_weight * llm_score
            scored["reranker_model"] = model_name
            return scored
        except Exception as e:
            logger.warning(f"Error processing item in LLM reranking: {e}, using combined_score")
            return fallback_copy(item, base_score(item))

    async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Score the items of one batch sequentially, preserving their order."""
        return [await score_item(item) for item in batch]

    reranked_results: Dict[str, List[Dict[str, Any]]] = {}
    for category in ["statements", "chunks", "entities", "summaries"]:
        items = results.get(category, [])
        if not items:
            reranked_results[category] = []
            continue

        # Only the best top_k items (by pre-LLM score) are sent to the LLM.
        sorted_items = sorted(items, key=base_score, reverse=True)
        top_items = sorted_items[:top_k]
        remaining_items = sorted_items[top_k:]

        batches = [top_items[i:i + batch_size] for i in range(0, len(top_items), batch_size)]

        try:
            batch_results = await asyncio.gather(
                *(process_batch(batch) for batch in batches),
                return_exceptions=True,
            )

            scored_items: List[Dict[str, Any]] = []
            for batch, outcome in zip(batches, batch_results):
                if isinstance(outcome, BaseException):
                    # Keep the batch's items with their pre-LLM score instead
                    # of silently dropping them from the search results.
                    logger.warning(f"Batch processing failed: {outcome}")
                    scored_items.extend(fallback_copy(item) for item in batch)
                    continue
                scored_items.extend(outcome)

            # Items outside the top K keep their combined score as final score.
            scored_items.extend(fallback_copy(item) for item in remaining_items)

            scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
            reranked_results[category] = scored_items
        except Exception as e:
            logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
            # Annotate copies so the caller's dicts are never mutated.
            reranked_results[category] = [fallback_copy(item) for item in items]

    return reranked_results
async def run_hybrid_search(
@@ -547,6 +558,7 @@ async def run_hybrid_search(
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
@@ -554,10 +566,14 @@ async def run_hybrid_search(
"""
Run search with specified type: 'keyword', 'embedding', or 'hybrid'
Args:
memory_config: MemoryConfig object containing embedding_model_id and config_id
"""
# Start overall timing
search_start_time = time.time()
latency_metrics = {}
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
@@ -610,7 +626,9 @@ async def run_hybrid_search(
# 从数据库读取嵌入器配置(按 ID),并构建 RedBearModelConfig
config_load_start = time.time()
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],
@@ -672,7 +690,7 @@ async def run_hybrid_search(
if use_forgetting_rerank:
# Load forgetting parameters from pipeline config
try:
pc = get_pipeline_config()
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
@@ -700,16 +718,16 @@ async def run_hybrid_search(
# Apply LLM reranking if enabled
llm_rerank_applied = False
if use_llm_rerank:
try:
reranked_results = await apply_llm_reranker(
results=reranked_results,
query_text=query_text,
)
llm_rerank_applied = True
logger.info("LLM reranking applied successfully")
except Exception as e:
logger.warning(f"LLM reranking failed: {e}, using previous scores")
# if use_llm_rerank:
# try:
# reranked_results = await apply_llm_reranker(
# results=reranked_results,
# query_text=query_text,
# )
# llm_rerank_applied = True
# logger.info("LLM reranking applied successfully")
# except Exception as e:
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results
results["combined_summary"] = {
@@ -759,18 +777,11 @@ async def run_hybrid_search(
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
completion_log = {
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"query": query_text,
"search_type": search_type,
"status": "completed",
"result_counts": result_counts,
"output_file": output_path,
"latency_metrics": latency_metrics
}
with open("search_log.txt", "a", encoding="utf-8") as f:
f.write(json.dumps(completion_log, ensure_ascii=False) + "\n")
# Log completion using the standard logger
logger.info(
f"Search completed: query='{query_text}', type={search_type}, "
f"result_counts={result_counts}, latency={latency_metrics}"
)
return results
@@ -892,89 +903,95 @@ async def search_chunk_by_chunk_id(
return {"chunks": chunks}
def main():
    """Main entry point for the hybrid graph search CLI.

    Parses command line arguments, loads the memory configuration identified
    by ``--config-id`` and runs ``run_hybrid_search`` with it. Supports
    keyword, embedding, and hybrid search modes.
    """
    parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
    parser.add_argument(
        "--query", "-q", required=True, help="Free-text query to search"
    )
    parser.add_argument(
        "--search-type",
        "-t",
        choices=["keyword", "embedding", "hybrid"],
        default="hybrid",
        help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
    )
    parser.add_argument(
        "--config-id",
        "-c",
        type=int,
        required=True,
        help="Database configuration ID (required)",
    )
    # Still accepted so existing command lines keep parsing; the value was
    # never used by this entry point, the embedder comes from the memory config.
    parser.add_argument(
        "--embedding-name",
        "-m",
        default="openai/nomic-embed-text:v1.5",
        help="Deprecated and ignored; the embedding model is taken from the memory configuration",
    )
    parser.add_argument(
        "--group-id",
        "-g",
        default=None,
        help="Optional group_id to filter results (default: None)",
    )
    parser.add_argument(
        "--limit",
        "-k",
        type=int,
        default=5,
        help="Max number of results per type (default: 5)",
    )
    parser.add_argument(
        "--include",
        "-i",
        nargs="+",
        default=["statements", "chunks", "entities", "summaries"],
        choices=["statements", "chunks", "entities", "summaries"],
        help="Which targets to search for embedding search (default: statements chunks entities summaries)"
    )
    parser.add_argument(
        "--output",
        "-o",
        default="search_results.json",
        help="Path to save the search results JSON (default: search_results.json)",
    )
    parser.add_argument(
        "--rerank-alpha",
        "-a",
        type=float,
        default=0.6,
        help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
    )
    parser.add_argument(
        "--forgetting-rerank",
        action="store_true",
        help="Apply forgetting curve during reranking for hybrid search.",
    )
    parser.add_argument(
        "--llm-rerank",
        action="store_true",
        help="Apply LLM-based reranking for hybrid search.",
    )
    args = parser.parse_args()

    # run_hybrid_search requires a MemoryConfig; load it from the database.
    # Called on an instance so it works for an instance, class or static method.
    # NOTE(review): load_memory_config(config_id) is assumed from the earlier
    # CLI draft — confirm its signature in MemoryConfigService.
    with get_db_context() as db:
        memory_config = MemoryConfigService(db).load_memory_config(args.config_id)

    asyncio.run(
        run_hybrid_search(
            query_text=args.query,
            search_type=args.search_type,
            group_id=args.group_id,
            limit=args.limit,
            include=args.include,
            output_path=args.output,
            memory_config=memory_config,
            rerank_alpha=args.rerank_alpha,
            use_forgetting_rerank=args.forgetting_rerank,
            use_llm_rerank=args.llm_rerank,
        )
    )


if __name__ == "__main__":
    main()

View File

@@ -1,19 +1,22 @@
"""
去重功能函数
"""
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple, Any
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
ExtractedEntityNode
)
import os
from datetime import datetime
import difflib # 提供字符串相似度计算工具
import asyncio
import difflib # 提供字符串相似度计算工具
import importlib
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
# 模块级类型统一工具函数
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
"""统一实体类型基于LLM建议或启发式规则选择最合适的类型。
@@ -705,7 +708,8 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], List[str]]:
"""
基于迭代分块并发的 LLM 判定,生成实体重定向并在本地应用融合。
@@ -717,26 +721,13 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
"""
llm_records: List[str] = []
try:
# 优先使用运行时配置;若未提供配置,使用模型默认值,不再回退到环境变量
enable_switch = (
bool(config.enable_llm_dedup_blockwise) if config is not None else DedupConfig().enable_llm_dedup_blockwise
)
if not enable_switch:
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数;若无配置则使用 DedupConfig 的默认值
_defaults = DedupConfig()
block_size = (config.llm_block_size if config is not None else _defaults.llm_block_size)
block_concurrency = (config.llm_block_concurrency if config is not None else _defaults.llm_block_concurrency)
pair_concurrency = (config.llm_pair_concurrency if config is not None else _defaults.llm_pair_concurrency)
max_rounds = (config.llm_max_rounds if config is not None else _defaults.llm_max_rounds)
# 动态导入 llm 客户端(修正导入路径)
try:
llm_utils_mod = importlib.import_module("app.core.memory.utils.llm.llm_utils")
get_llm_client_fn = llm_utils_mod.get_llm_client
except Exception as e:
llm_records.append(f"[LLM错误] 无法导入 llm_utils 模块: {e}")
if not bool(config.enable_llm_dedup_blockwise):
return deduped_entities, id_redirect, llm_records
# 从配置读取 LLM 迭代参数
block_size = config.llm_block_size
block_concurrency = config.llm_block_concurrency
pair_concurrency = config.llm_pair_concurrency
max_rounds = config.llm_max_rounds
try:
llm_mod = importlib.import_module("app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm")
@@ -745,14 +736,9 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
llm_records.append(f"[LLM错误] 无法导入 entity_dedup_llm 模块: {e}")
return deduped_entities, id_redirect, llm_records
# 获取 LLM 客户端
try:
llm_client = get_llm_client_fn()
if llm_client is None:
llm_records.append("[LLM错误] LLM 客户端初始化失败:返回 None")
return deduped_entities, id_redirect, llm_records
except Exception as e:
llm_records.append(f"[LLM错误] 获取 LLM 客户端失败: {e}")
# 验证 LLM 客户端
if llm_client is None:
llm_records.append("[LLM错误] LLM 客户端未提供")
return deduped_entities, id_redirect, llm_records
llm_redirect, llm_records = await llm_fn(
@@ -813,7 +799,8 @@ async def LLM_disamb_decision(
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
id_redirect: Dict[str, str],
config: DedupConfig | None = None,
config: DedupConfig,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], set[tuple[str, str]], List[str]]:
"""
预消歧阶段对“同名但类型不同”的实体对调用LLM进行消歧
@@ -824,22 +811,16 @@ async def LLM_disamb_decision(
disamb_records: List[str] = []
blocked_pairs: set[tuple[str, str]] = set()
try:
enable_switch = (
config.enable_llm_disambiguation
if config is not None
else DedupConfig().enable_llm_disambiguation
)
if not bool(enable_switch):
if not bool(config.enable_llm_disambiguation):
return deduped_entities, id_redirect, blocked_pairs, disamb_records
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import llm_disambiguate_pairs_iterative
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.storage_services.extraction_engine.deduplication.entity_dedup_llm import (
llm_disambiguate_pairs_iterative,
)
# 获取 LLM 客户端并验证
llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
# 验证 LLM 客户端
if llm_client is None:
disamb_records.append("[DISAMB错误] LLM 客户端初始化失败:返回 None")
disamb_records.append("[DISAMB错误] LLM 客户端未提供")
return deduped_entities, id_redirect, blocked_pairs, disamb_records
merge_redirect, block_list, disamb_records = await llm_disambiguate_pairs_iterative(
@@ -895,6 +876,7 @@ async def deduplicate_entities_and_edges(
report_append: bool = False,
report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[
List[ExtractedEntityNode],
List[StatementEntityEdge],
@@ -911,7 +893,7 @@ async def deduplicate_entities_and_edges(
# 1.5) LLM 决策消歧:阻断同名不同类型的高相似对,并应用必要的合并
deduped_entities, id_redirect, blocked_pairs, disamb_records = await LLM_disamb_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
)
# 2) 模糊匹配(本地规则)
@@ -936,7 +918,7 @@ async def deduplicate_entities_and_edges(
if should_trigger_llm:
deduped_entities, id_redirect, llm_decision_records = await LLM_decision(
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config
deduped_entities, statement_entity_edges, entity_entity_edges, id_redirect, config=dedup_config, llm_client=llm_client
)
else:
llm_decision_records = []

View File

@@ -10,15 +10,27 @@
from __future__ import annotations
from typing import List, Dict, Any, Tuple
from datetime import datetime
from typing import Any, Dict, List, Tuple
from app.core.memory.models.graph_models import (
EntityEntityEdge,
ExtractedEntityNode,
StatementEntityEdge,
)
from app.core.memory.models.variate_config import DedupConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import ( # 导入报告写入以在跳过时追加说明
_write_dedup_fusion_report,
deduplicate_entities_and_edges,
)
from app.repositories.neo4j.graph_search import (
get_dedup_candidates_for_entities, # 导入ge函数用于从 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # 导入 Neo4j 数据库连接器类,用于与 Neo4j 数据库进行交互
from app.repositories.neo4j.graph_search import get_dedup_candidates_for_entities # 导入ge函数,用于 Neo4j 中检索与输入实体可能重复的候选实体(去重的核心检索逻辑)。
from app.core.memory.models.graph_models import ExtractedEntityNode, StatementEntityEdge, EntityEntityEdge
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges, _write_dedup_fusion_report # 导入报告写入以在跳过时追加说明
from app.core.memory.models.variate_config import DedupConfig
from app.repositories.neo4j.neo4j_connector import (
Neo4jConnector, # 导入 Neo4j 数据库连接器类,用于 Neo4j 数据库进行交互
)
def _parse_dt(val: Any) -> datetime: # 定义内部辅助函数_parse_dt用于将任意类型的输入值解析为datetime对象处理实体节点中的时间字段
@@ -72,6 +84,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
dedup_config: DedupConfig | None = None,
llm_client = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
"""
第二层去重消歧:
@@ -137,13 +150,14 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
union_entities: List[ExtractedEntityNode] = db_candidate_models + list(entity_nodes)
# 融合(内部执行精确/模糊/LLM 决策;随后再做边重定向与去重)
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges = await deduplicate_entities_and_edges(
fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges, _ = await deduplicate_entities_and_edges(
union_entities,
statement_entity_edges,
entity_entity_edges,
report_stage="第二层去重消歧",
report_append=True,
dedup_config=dedup_config,
llm_client=llm_client,
)
return fused_entities, fused_stmt_entity_edges, fused_entity_entity_edges

View File

@@ -1,23 +1,27 @@
from __future__ import annotations
from typing import List, Tuple, Optional
from typing import List, Optional, Tuple
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.utils.config.config_utils import get_pipeline_config
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import second_layer_dedup_and_merge_with_neo4j
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
StatementNode,
)
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import ExtractionPipelineConfig
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
second_layer_dedup_and_merge_with_neo4j,
)
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return(
@@ -29,8 +33,9 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData],
pipeline_config: Optional[ExtractionPipelineConfig] = None,
pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None,
llm_client = None,
) -> Tuple[
List[DialogueNode],
List[ChunkNode],
@@ -48,12 +53,9 @@ async def dedup_layers_and_merge_and_return(
返回融合后的实体与边,同时保留原始的对话、片段与语句节点与边。
"""
# 默认从 runtime.json 加载管线配置,避免回退到环境变量
# pipeline_config is required - caller must provide it
if pipeline_config is None:
try:
pipeline_config = get_pipeline_config()
except Exception:
pipeline_config = None
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略
group_id: Optional[str] = None
@@ -70,6 +72,7 @@ async def dedup_layers_and_merge_and_return(
report_stage="第一层去重消歧",
report_append=False,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
)
# 初始化第二层融合结果为第一层结果
@@ -88,6 +91,7 @@ async def dedup_layers_and_merge_and_return(
statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges,
dedup_config=(pipeline_config.deduplication if pipeline_config else None),
llm_client=llm_client,
)
else:
print("Skip second-layer dedup: missing connector")

View File

@@ -19,48 +19,49 @@
import asyncio
import logging
import os
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from app.core.memory.models.message_models import DialogData
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementNode,
DialogueNode,
EntityEntityEdge,
ExtractedEntityNode,
StatementChunkEdge,
StatementEntityEdge,
EntityEntityEdge,
StatementNode,
)
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
)
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
embedding_generation_all,
generate_entity_embeddings_from_triplets,
)
# 导入各个提取模块
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
StatementExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
TripletExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.temporal_extraction import (
TemporalExtractor,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
generate_entity_embeddings_from_triplets,
)
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.triplet_extraction import (
TripletExtractor,
)
from app.core.memory.storage_services.extraction_engine.pipeline_help import (
_write_extracted_result_summary,
export_test_input_doc,
)
from app.core.memory.utils.data.ontology import TemporalInfo
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 配置日志
logger = logging.getLogger(__name__)
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None,
):
"""
初始化流水线编排器
@@ -106,6 +108,7 @@ class ExtractionOrchestrator:
progress_callback: 进度回调函数
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
- 在管线关键点调用以报告进度和结果数据
embedding_id: 嵌入模型ID如果为 None 则从全局配置获取(向后兼容)
"""
self.llm_client = llm_client
self.embedder_client = embedder_client
@@ -113,6 +116,7 @@ class ExtractionOrchestrator:
self.config = config or ExtractionPipelineConfig()
self.is_pilot_run = False # 默认非试运行模式
self.progress_callback = progress_callback # 保存进度回调函数
self.embedding_id = embedding_id # 保存嵌入模型ID
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
@@ -398,7 +402,9 @@ class ExtractionOrchestrator:
except Exception as e:
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
completed_statements += 1
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
return TripletExtractionResponse(triplets=[], entities=[])
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
@@ -412,7 +418,9 @@ class ExtractionOrchestrator:
d_idx, stmt_id = statement_metadata[i]
if isinstance(result, Exception):
logger.error(f"陈述句处理异常: {result}")
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
triplet_maps[d_idx][stmt_id] = TripletExtractionResponse(triplets=[], entities=[])
else:
triplet_maps[d_idx][stmt_id] = result
@@ -525,8 +533,8 @@ class ExtractionOrchestrator:
temporal_maps[d_idx][stmt_id] = result
# 为 ATEMPORAL 陈述句添加空的时间范围
from app.core.memory.utils.data.ontology import TemporalInfo
from app.core.memory.models.message_models import TemporalValidityRange
from app.core.memory.utils.data.ontology import TemporalInfo
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
@@ -738,17 +746,14 @@ class ExtractionOrchestrator:
logger.info("开始生成基础嵌入向量(陈述句、分块、对话)")
try:
# 从 runtime.json 获取嵌入模型配置ID
from app.core.memory.utils.config import definitions as config_defs
embedding_id = config_defs.SELECTED_EMBEDDING_ID
if not embedding_id:
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
raise ValueError("未配置嵌入模型ID")
# embedding_id is required - no fallback to global variable
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
raise ValueError("embedding_id is required but was not provided")
# 只生成陈述句、分块和对话的嵌入(不包括实体)
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = await embedding_generation(
dialog_data_list, embedding_id
dialog_data_list, self.embedding_id
)
# 统计生成结果
@@ -792,17 +797,14 @@ class ExtractionOrchestrator:
logger.info("开始生成实体嵌入向量")
try:
# 从 runtime.json 获取嵌入模型配置ID
from app.core.memory.utils.config import definitions as config_defs
embedding_id = config_defs.SELECTED_EMBEDDING_ID
if not embedding_id:
logger.error("未在 runtime.json 中配置 embedding 模型 ID")
# embedding_id is required - no fallback to global variable
if not self.embedding_id:
logger.error("embedding_id is required but was not provided to ExtractionOrchestrator")
return triplet_maps
# 生成实体嵌入
updated_triplet_maps = await generate_entity_embeddings_from_triplets(
triplet_maps, embedding_id
triplet_maps, self.embedding_id
)
logger.info("实体嵌入生成完成")
@@ -1240,7 +1242,9 @@ class ExtractionOrchestrator:
if self.is_pilot_run:
logger.info("试运行模式:仅执行第一层去重,跳过第二层数据库去重")
# 只执行第一层去重
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
deduplicate_entities_and_edges,
)
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
@@ -1249,6 +1253,7 @@ class ExtractionOrchestrator:
report_stage="第一层去重消歧(试运行)",
report_append=False,
dedup_config=self.config.deduplication,
llm_client=self.llm_client,
)
# 保存去重消歧的详细记录到实例变量
@@ -1280,6 +1285,7 @@ class ExtractionOrchestrator:
dialog_data_list,
self.config,
self.connector,
llm_client=self.llm_client,
)
# 解包返回值
@@ -1824,7 +1830,9 @@ async def get_chunked_dialogs(
)
# 创建分块器并处理对话
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
@@ -1871,7 +1879,9 @@ def preprocess_data(
经过清洗转换后的 DialogData 列表
"""
print("\n=== 数据预处理 ===")
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
preprocessor = DataPreprocessor()
try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
@@ -1902,7 +1912,9 @@ async def get_chunked_dialogs_from_preprocessed(
raise ValueError("预处理数据为空,无法进行分块")
all_chunked_dialogs: List[DialogData] = []
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
for dialog_data in data:
chunker = DialogueChunker(chunker_strategy, llm_client=llm_client)
@@ -1963,7 +1975,9 @@ async def get_chunked_dialogs_with_preprocessing(
# 步骤2: 语义剪枝
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
pruner = SemanticPruner(llm_client=llm_client)
# 记录单对话场景下剪枝前的消息数量
@@ -1986,7 +2000,9 @@ async def get_chunked_dialogs_with_preprocessing(
# 保存剪枝后的数据
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import DataPreprocessor
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
pruned_output_path = settings.get_memory_output_path("pruned_data.json")
dp = DataPreprocessor(output_file_path=pruned_output_path)
dp.save_data(preprocessed_data, output_path=pruned_output_path)

View File

@@ -5,11 +5,13 @@
"""
import asyncio
from typing import List, Dict, Any, Tuple
from app.core.memory.models.message_models import DialogData
from typing import Any, Dict, List, Tuple
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
class EmbeddingGenerator:
@@ -21,7 +23,9 @@ class EmbeddingGenerator:
Args:
embedding_id: 嵌入模型 ID
"""
embedder_config = get_embedder_config(embedding_id)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config = config_service.get_embedder_config(embedding_id)
self.embedder_client = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_config),
)

View File

@@ -1,21 +1,17 @@
import os
import asyncio
from datetime import datetime
from typing import List, Optional
from pydantic import Field, field_validator
from uuid import uuid4
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.message_models import DialogData
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from pydantic import Field
logger = get_memory_logger(__name__)
from app.core.memory.models.graph_models import MemorySummaryNode
from app.core.memory.models.base_response import RobustLLMResponse
from app.core.models.base import RedBearModelConfig
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
from uuid import uuid4
class MemorySummaryResponse(RobustLLMResponse):
@@ -91,22 +87,17 @@ async def _process_chunk_summary(
return None
async def Memory_summary_generation(
async def memory_summary_generation(
chunked_dialogs: List[DialogData],
llm_client,
embedding_id,
embedder_client: OpenAIEmbedderClient,
) -> List[MemorySummaryNode]:
"""Generate memory summaries per chunk, embed them, and return nodes."""
embedder_cfg_dict = get_embedder_config(embedding_id)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(embedder_cfg_dict),
)
# Collect all tasks for parallel processing
tasks = []
for dialog in chunked_dialogs:
for chunk in dialog.chunks:
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder))
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
# Process all chunks in parallel
results = await asyncio.gather(*tasks, return_exceptions=False)

View File

@@ -1,17 +1,21 @@
import os
import asyncio
import logging
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
RelevenceInfo,
StatementType,
TemporalInfo,
)
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS, StatementType, TemporalInfo, RelevenceInfo
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)

View File

@@ -8,35 +8,43 @@
4. 反思结果应用 - 更新记忆库
"""
import asyncio
import json
import logging
import asyncio
import os
import time
from typing import List, Dict, Any, Optional
from enum import Enum
import uuid
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.config.get_data import get_data,get_data_statement,extract_and_process_changes
from app.core.memory.utils.config.get_data import (
extract_and_process_changes,
get_data,
get_data_statement,
)
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.prompt.template_render import (
render_evaluate_prompt,
render_reflexion_prompt,
)
from app.core.models.base import RedBearModelConfig
from app.core.response_utils import success
from app.repositories.neo4j.cypher_queries import (
UPDATE_STATEMENT_INVALID_AT,
neo4j_query_all,
neo4j_query_part,
neo4j_statement_all,
neo4j_statement_part,
)
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.neo4j_update import neo4j_data
from app.schemas.memory_storage_schema import ConflictResultSchema
from app.schemas.memory_storage_schema import ReflexionResultSchema
from app.schemas.memory_storage_schema import (
ConflictResultSchema,
ReflexionResultSchema,
)
from pydantic import BaseModel
# 配置日志
_root_logger = logging.getLogger()
@@ -152,12 +160,26 @@ class ReflectionEngine:
self.neo4j_connector = Neo4jConnector()
if self.llm_client is None:
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端
# from app.core.memory.utils.llm.llm_utils import get_llm_client
# model_id = self.llm_client
# self.llm_client = get_llm_client(model_id)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
model_id = self.llm_client
with get_db_context() as db:
factory = MemoryClientFactory(db)
# self.llm_client = factory.get_llm_client(model_id)
# Use MemoryConfigService to get model config
config_service = MemoryConfigService(db)
model_config = config_service.get_model_config(model_id)
extra_params={
"temperature": 0.2, # 降低温度提高响应速度和一致性
"max_tokens": 600, # 限制最大token数
@@ -165,7 +187,6 @@ class ReflectionEngine:
"stream": False, # 确保非流式输出以获得最快响应
}
model_config = get_model_config(self.llm_client)
self.llm_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
@@ -184,9 +205,15 @@ class ReflectionEngine:
self.get_data_statement = get_data_statement
if self.render_evaluate_prompt_func is None:
from app.core.memory.utils.prompt.template_render import (
render_evaluate_prompt,
)
self.render_evaluate_prompt_func = render_evaluate_prompt
if self.render_reflexion_prompt_func is None:
from app.core.memory.utils.prompt.template_render import (
render_reflexion_prompt,
)
self.render_reflexion_prompt_func = render_reflexion_prompt
if self.conflict_schema is None:
@@ -196,6 +223,9 @@ class ReflectionEngine:
self.reflexion_schema = ReflexionResultSchema
if self.update_query is None:
from app.repositories.neo4j.cypher_queries import (
UPDATE_STATEMENT_INVALID_AT,
)
self.update_query = UPDATE_STATEMENT_INVALID_AT
self._lazy_init_done = True

View File

@@ -4,10 +4,20 @@
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
"""
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.storage_services.search.semantic_search import (
SemanticSearchStrategy,
)
__all__ = [
"SearchStrategy",
@@ -34,7 +44,7 @@ async def run_hybrid_search(
include: list[str] | None = None,
alpha: float = 0.6,
use_forgetting_curve: bool = False,
embedding_id: str | None = None,
memory_config: "MemoryConfig" = None,
**kwargs
) -> dict:
"""运行混合搜索向后兼容的函数式API
@@ -51,24 +61,26 @@ async def run_hybrid_search(
include: 要包含的搜索类别列表
alpha: BM25分数权重0.0-1.0
use_forgetting_curve: 是否使用遗忘曲线
embedding_id: 嵌入模型ID
memory_config: MemoryConfig object containing embedding_model_id
**kwargs: 其他参数
Returns:
dict: 搜索结果字典格式与旧API兼容
"""
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 使用提供的embedding_id或默认值
emb_id = embedding_id or config_defs.SELECTED_EMBEDDING_ID
if not memory_config:
raise ValueError("memory_config is required for search")
# 初始化客户端
connector = Neo4jConnector()
embedder_config_dict = get_embedder_config(emb_id)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)

View File

@@ -1,447 +0,0 @@
# TODO: hybrid_chatbot.py 是一个独立的 GUI 演示应用,不是核心功能的一部分,可以考虑删除。
from app.core.memory.utils.llm.llm_utils import get_llm_client
import asyncio
import os
import time
import json
from datetime import datetime, timezone
import tkinter as tk
from tkinter import scrolledtext, messagebox
import threading
from typing import Any, Dict, Tuple, List
# Import our hybrid search functionality
from app.core.memory.storage_services.search import run_hybrid_search
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.config_models import LLMConfig
from dotenv import load_dotenv
load_dotenv()
class HybridSearchChatbot:
def __init__(self):
    """Create the LLM client, initialise chat/search state, and build the GUI.

    The LLM client comes from ``get_llm_client`` called with the
    ``SELECTED_LLM_ID`` value of the memory configuration definitions.
    ``setup_gui`` runs last, after ``chat_history`` and ``search_config``
    have been assigned.
    """
    # Function-scope import, kept local to the constructor as in the original.
    from app.core.memory.utils.config import definitions as config_defs

    selected_llm_id = config_defs.SELECTED_LLM_ID
    self.llm_client = get_llm_client(selected_llm_id)

    # Conversation log; starts out empty.
    self.chat_history = []

    # Default search settings; ``show_config_dialog`` lets the user change
    # ``rerank_alpha``, ``limit`` and ``include``.
    # NOTE(review): the group id is hard-coded — presumably a demo group; confirm.
    default_targets = ["statements", "chunks", "entities", "summaries"]
    self.search_config = {
        "group_id": "group_wyl_25",
        "limit": 10,
        "include": default_targets,
        "rerank_alpha": 0.6,
    }

    # Build the window and its widgets.
    self.setup_gui()
def setup_gui(self):
    """Build the main window: chat log, input row, status row, welcome text.

    Stores the widgets that other methods use on ``self``: ``root``,
    ``chat_display``, ``user_input``, ``send_button`` and ``status_label``.
    Widgets are created and packed top to bottom in the order below.
    """
    entry_font = ("Arial", 12)
    status_font = ("Arial", 10)

    window = tk.Tk()
    self.root = window
    window.title("Hybrid Search Chatbot")
    window.geometry("800x600")

    # Scrollable chat log; created disabled, so text is only added through
    # add_message(), which enables it temporarily.
    chat_log = scrolledtext.ScrolledText(
        window, wrap=tk.WORD, width=80, height=25, state=tk.DISABLED
    )
    self.chat_display = chat_log
    chat_log.pack(padx=10, pady=10, fill=tk.BOTH, expand=True)

    # Input row: text entry on the left, send button on the right.
    entry_row = tk.Frame(window)
    entry_row.pack(padx=10, pady=5, fill=tk.X)

    entry = tk.Entry(entry_row, font=entry_font)
    self.user_input = entry
    entry.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=(0, 5))
    # The Return key sends the message, same as clicking the button.
    entry.bind("<Return>", self.on_send_message)

    send = tk.Button(
        entry_row, text="发送", command=self.on_send_message, font=entry_font
    )
    self.send_button = send
    send.pack(side=tk.RIGHT)

    # Status row: status text on the left, search-settings button on the right.
    status_row = tk.Frame(window)
    status_row.pack(padx=10, pady=5, fill=tk.X)

    status = tk.Label(status_row, text="就绪", font=status_font, anchor="w")
    self.status_label = status
    status.pack(side=tk.LEFT, fill=tk.X, expand=True)

    settings_button = tk.Button(
        status_row, text="搜索配置", command=self.show_config_dialog, font=status_font
    )
    settings_button.pack(side=tk.RIGHT)

    # Greet the user once the chat log exists.
    self.add_message("系统", "欢迎使用混合搜索聊天机器人!我可以基于知识图谱中的信息回答您的问题。")
def add_message(self, sender: str, message: str, metadata: Dict = None):
    """Append one timestamped entry to the chat log and scroll to the end.

    Args:
        sender: Name shown in the entry's header line.
        message: Body text of the entry.
        metadata: Optional extra information; when truthy it is rendered on
            its own line below the message.
    """
    log = self.chat_display

    # The log is kept read-only; enable it only while inserting.
    log.config(state=tk.NORMAL)

    now_text = datetime.now().strftime("%H:%M:%S")

    # Header line (timestamp + sender), then the message body.
    log.insert(tk.END, f"[{now_text}] {sender}:\n", "sender")
    log.insert(tk.END, f"{message}\n", "message")

    # Optional metadata line.
    if metadata:
        log.insert(tk.END, f" {metadata}\n", "metadata")

    # Blank separator line, then lock the widget again and show the newest entry.
    log.insert(tk.END, "\n")
    log.config(state=tk.DISABLED)
    log.see(tk.END)

    # Styling for the three tags used above (re-applied on every call).
    tag_styles = (
        ("sender", "blue", ("Arial", 10, "bold")),
        ("message", "black", ("Arial", 10)),
        ("metadata", "gray", ("Arial", 8)),
    )
    for tag_name, colour, tag_font in tag_styles:
        log.tag_config(tag_name, foreground=colour, font=tag_font)
def show_config_dialog(self):
    """Open a modal dialog for editing ``self.search_config``.

    The dialog shows the current settings and offers a slider for
    ``rerank_alpha``, a spinbox for ``limit`` and one checkbox per search
    target (``include``). Saving writes the three values back into
    ``self.search_config``, closes the dialog and posts a confirmation
    message to the chat log; cancelling just closes the dialog.
    """
    dialog = tk.Toplevel(self.root)
    dialog.title("搜索配置")
    dialog.geometry("400x600")
    # Tie the dialog to the main window and grab input so it acts modally.
    dialog.transient(self.root)
    dialog.grab_set()

    # --- Summary of the settings currently in effect -------------------
    summary_frame = tk.Frame(dialog)
    summary_frame.pack(pady=10, padx=10, fill=tk.X)

    tk.Label(summary_frame, text="当前配置:", font=("Arial", 10, "bold")).pack(anchor="w")
    current_text = f"Alpha: {self.search_config['rerank_alpha']}, 限制: {self.search_config['limit']}, 目标: {', '.join(self.search_config['include'])}"
    tk.Label(summary_frame, text=current_text, font=("Arial", 9), fg="blue").pack(anchor="w")

    # --- Rerank weight (alpha) slider ----------------------------------
    tk.Label(dialog, text="重排权重 (Alpha):").pack(pady=(10, 5))
    alpha_var = tk.DoubleVar(value=self.search_config["rerank_alpha"])
    slider = tk.Scale(
        dialog,
        from_=0.0,
        to=1.0,
        resolution=0.1,
        orient=tk.HORIZONTAL,
        variable=alpha_var,
    )
    slider.pack(pady=5, padx=20, fill=tk.X)
    tk.Label(dialog, text="0.0=纯语义搜索, 1.0=纯关键词搜索", font=("Arial", 8)).pack()

    # --- Result-count spinbox ------------------------------------------
    tk.Label(dialog, text="搜索结果数量:").pack(pady=(20, 5))
    limit_var = tk.IntVar(value=self.search_config["limit"])
    count_box = tk.Spinbox(dialog, from_=1, to=50, textvariable=limit_var, width=10)
    count_box.pack(pady=5)

    # --- Search-target checkboxes --------------------------------------
    tk.Label(dialog, text="搜索目标:").pack(pady=(20, 5))
    targets_frame = tk.Frame(dialog)
    targets_frame.pack(pady=5)

    include_vars = {}
    for target in ("statements", "chunks", "entities", "summaries"):
        # Pre-tick the targets that are currently enabled.
        checked = tk.BooleanVar(value=target in self.search_config["include"])
        include_vars[target] = checked
        tk.Checkbutton(targets_frame, text=target, variable=checked).pack(side=tk.LEFT, padx=10)

    # --- Save / cancel buttons -----------------------------------------
    actions_frame = tk.Frame(dialog)
    actions_frame.pack(pady=20)

    def save_config():
        """Validate the dialog values and store them in ``self.search_config``."""
        try:
            alpha_value = alpha_var.get()
            limit_value = limit_var.get()
            include_list = [name for name, flag in include_vars.items() if flag.get()]

            if include_list:
                self.search_config["rerank_alpha"] = alpha_value
                self.search_config["limit"] = limit_value
                self.search_config["include"] = include_list

                dialog.destroy()
                self.add_message("系统",
                                 f"配置已更新: Alpha={alpha_value:.1f}, 限制={limit_value}, 目标={', '.join(include_list)}")
            else:
                # At least one search target must stay selected; keep the dialog open.
                messagebox.showerror("配置错误", "请至少选择一个搜索目标!")
        except Exception as e:
            # Any failure while reading or applying the values is reported
            # to the user and echoed to stdout.
            messagebox.showerror("配置错误", f"保存配置时出错: {str(e)}")
            print(f"Config save error: {e}")

    tk.Button(actions_frame, text="保存", command=save_config).pack(side=tk.LEFT, padx=5)
    tk.Button(actions_frame, text="取消", command=dialog.destroy).pack(side=tk.LEFT, padx=5)
def on_send_message(self, event=None):
    """Send the text currently typed into the input field.

    ``event`` is accepted but not used, so the method can be called both
    directly and with an event argument.  Blank input is ignored.
    Otherwise the input is cleared, the message is shown in the chat, the
    send button is disabled and the message is handed to
    ``process_message_async`` on a daemon thread.
    """
    text = self.user_input.get().strip()
    if text:
        # Reset the entry and echo the user's message.
        self.user_input.delete(0, tk.END)
        self.add_message("用户", text)

        # Block further sends while this one is being processed.
        self.send_button.config(state=tk.DISABLED)
        self.status_label.config(text="正在搜索和生成回复...")

        # Do the slow work off the calling thread.
        worker = threading.Thread(
            target=self.process_message_async,
            args=(text,),
            daemon=True,
        )
        worker.start()
def process_message_async(self, user_message: str):
    """Run ``process_message`` to completion on a private event loop.

    Creates a new asyncio event loop, drives ``self.process_message`` with
    it, and then schedules the outcome through ``self.root.after``:
    ``self.on_response_ready(response, metadata)`` on success, or
    ``self.on_error(message)`` when anything raises.

    Args:
        user_message: The text to process.
    """
    try:
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            response, metadata = loop.run_until_complete(
                self.process_message(user_message)
            )
        finally:
            # Always release the loop, including when processing failed,
            # and do not leave a closed loop registered as the current one.
            asyncio.set_event_loop(None)
            loop.close()

        # Hand the result over via the Tk `after` queue.
        self.root.after(0, self.on_response_ready, response, metadata)
    except Exception as e:
        error_msg = f"处理消息时出错: {str(e)}"
        self.root.after(0, self.on_error, error_msg)
async def process_message(self, user_message: str) -> Tuple[str, Dict[str, Any]]:
    """Answer one user message: hybrid search first, then LLM generation.

    Returns:
        A ``(response, metadata)`` pair.  ``metadata`` holds the measured
        search, generation and total durations (formatted in seconds), a
        summary of the search results and the rerank weight from
        ``self.search_config``.
    """
    t_begin = time.time()

    # Retrieval step, parameterised by the current search configuration.
    t_search = time.time()
    cfg = self.search_config
    hits = await run_hybrid_search(
        query_text=user_message,
        search_type="hybrid",
        group_id=cfg["group_id"],
        limit=cfg["limit"],
        include=cfg["include"],
        output_path=None,
        rerank_alpha=cfg["rerank_alpha"],
    )
    search_elapsed = time.time() - t_search

    # Turn the raw hits into prompt context.
    context_info = self.extract_context_from_search(hits)

    # Generation step.
    t_llm = time.time()
    reply = await self.generate_response(user_message, context_info)
    llm_elapsed = time.time() - t_llm

    total_elapsed = time.time() - t_begin

    return reply, {
        "搜索时间": f"{search_elapsed:.2f}s",
        "生成时间": f"{llm_elapsed:.2f}s",
        "总时间": f"{total_elapsed:.2f}s",
        "搜索结果": self.get_search_summary(hits),
        "重排权重": self.search_config["rerank_alpha"],
    }
def extract_context_from_search(self, search_results: Dict) -> str:
    """Build the textual context block that is given to the LLM.

    Uses ``search_results["reranked_results"]`` when that key is present;
    otherwise the per-category lists under ``"keyword_search"`` and
    ``"embedding_search"`` are concatenated.  The output contains up to
    five statements, three chunks and five entity names.

    Missing, ``None`` or non-numeric scores are rendered as ``0.000``
    instead of raising, and entities without a usable name are left out.

    Args:
        search_results: Mapping returned by the search step; may be empty.

    Returns:
        The formatted context, or a fixed "nothing found" sentence when
        there is nothing to report.
    """
    nothing_found = "未找到相关信息。"
    if not search_results:
        return nothing_found

    # Prefer reranked results; otherwise merge the individual searches.
    if "reranked_results" in search_results:
        results = search_results["reranked_results"] or {}
    else:
        results = {}
        for key in ("keyword_search", "embedding_search"):
            for category, items in (search_results.get(key) or {}).items():
                results.setdefault(category, []).extend(items or [])

    def _score(item):
        """Return the item's combined score, else its score, as a float."""
        raw = item.get("combined_score")
        if raw is None:
            raw = item.get("score")
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    def _scored_lines(items, text_key):
        """Format items as numbered lines with their relevance score."""
        return [
            f"{i}. {item.get(text_key, '')} (相关度: {_score(item):.3f})"
            for i, item in enumerate(items, 1)
        ]

    context_parts = []

    statements = (results.get("statements") or [])[:5]  # Top 5
    if statements:
        context_parts.append("相关陈述:")
        context_parts.extend(_scored_lines(statements, "statement"))

    chunks = (results.get("chunks") or [])[:3]  # Top 3
    if chunks:
        context_parts.append("\n相关对话:")
        context_parts.extend(_scored_lines(chunks, "content"))

    entities = (results.get("entities") or [])[:5]  # Top 5
    # Drop missing/empty names so the joined list has no blank entries.
    entity_names = [str(name) for name in (ent.get("name") for ent in entities) if name]
    if entity_names:
        context_parts.append("\n相关实体:")
        context_parts.append(", ".join(entity_names))

    return "\n".join(context_parts) if context_parts else nothing_found
def get_search_summary(self, search_results: Dict) -> str:
    """Summarise the result counts of a search in one line.

    Reads the counters found under ``search_results["combined_summary"]``
    and joins those that are present, in a fixed order.  Falls back to a
    generic "has results" text when no counter is available, and to a
    "no results" text when ``search_results`` is empty.
    """
    if not search_results:
        return "无结果"

    # (counter key, label shown to the user), in display order.
    labelled_counters = (
        ("total_reranked_results", "重排结果"),
        ("total_keyword_results", "关键词"),
        ("total_embedding_results", "语义"),
    )

    pieces = []
    if "combined_summary" in search_results:
        counters = search_results["combined_summary"]
        pieces = [
            f"{label}: {counters[key]}"
            for key, label in labelled_counters
            if key in counters
        ]

    if not pieces:
        return "有结果"
    return ", ".join(pieces)
async def generate_response(self, user_message: str, context: str) -> str:
    """Ask the LLM for an answer grounded in the retrieved context.

    Sends a system prompt containing ``context`` plus the user's message to
    ``self.llm_client.chat`` and extracts the reply text from whatever the
    client returns.  Supported shapes, checked in this order: an object
    with ``.content``; an object with ``.choices``; a dict with
    ``"content"`` or ``"choices"``; a plain string.

    Args:
        user_message: The user's question.
        context: Retrieved information to embed in the system prompt.

    Returns:
        The reply text, a fixed apology when no text could be extracted,
        or an error description when the call raised.
    """
    # Local import keeps this block self-contained.
    import inspect

    system_prompt = f"""你是一个智能助手,基于知识图谱中的信息回答用户问题。
以下是从知识图谱中检索到的相关信息:
{context}
请基于这些信息回答用户的问题。如果信息不足,请诚实地说明。回答要自然、友好,并且准确。"""
    try:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]
        response = self.llm_client.chat(
            messages=messages,
        )
        # Works with both synchronous and coroutine-based clients: an
        # un-awaited coroutine would otherwise match none of the shapes
        # below and always produce the apology text.
        if inspect.isawaitable(response):
            response = await response

        # 1) LangChain AIMessage or similar object with `.content`
        content = getattr(response, 'content', None)
        if content is not None:
            return content

        # 2) OpenAI-style response with `.choices`
        choices = getattr(response, 'choices', None)
        if choices:
            first_choice = choices[0]
            # Some clients expose `.message.content`, others `.content`.
            message = getattr(first_choice, 'message', None)
            if hasattr(message, 'content'):
                return message.content
            if hasattr(first_choice, 'content'):
                return first_choice.content

        # 3) Dict-like responses
        if isinstance(response, dict):
            if 'content' in response:
                return response['content']
            if response.get('choices'):
                ch = response['choices'][0]
                if isinstance(ch, dict):
                    if 'message' in ch and 'content' in ch['message']:
                        return ch['message']['content']
                    if 'content' in ch:
                        return ch['content']

        # 4) Plain string
        if isinstance(response, str):
            return response

        # Nothing recognisable came back.
        return "抱歉,我无法生成回复。"
    except Exception as e:
        return f"生成回复时出错: {str(e)}"
def on_response_ready(self, response: str, metadata: Dict[str, Any]):
    """Display the assistant's reply and return the UI to the ready state.

    Args:
        response: Reply text to show in the chat.
        metadata: Extra information passed along to ``add_message``.
    """
    self.add_message("助手", response, metadata)

    # Re-enable sending and put the cursor back into the input field.
    send_button, status_label = self.send_button, self.status_label
    send_button.config(state=tk.NORMAL)
    status_label.config(text="就绪")
    self.user_input.focus()
def on_error(self, error_message: str):
    """Display an error as a system message and return the UI to the ready state.

    Args:
        error_message: Text describing what went wrong.
    """
    self.add_message("系统", f" {error_message}")

    # Re-enable sending and put the cursor back into the input field.
    send_button, status_label = self.send_button, self.status_label
    send_button.config(state=tk.NORMAL)
    status_label.config(text="就绪")
    self.user_input.focus()
def run(self):
    """Enter the Tk main loop of the chatbot window."""
    root = self.root
    root.mainloop()
def main():
    """Create the chatbot and run it; print a message if that fails."""
    try:
        HybridSearchChatbot().run()
    except Exception as exc:
        print(f"启动聊天机器人时出错: {exc}")


# Start the chatbot only when the file is executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -1,408 +1,408 @@
# -*- coding: utf-8 -*-
"""混合搜索策略
# # -*- coding: utf-8 -*-
# """混合搜索策略
结合关键词搜索和语义搜索的混合检索方法。
支持结果重排序和遗忘曲线加权。
"""
# 结合关键词搜索和语义搜索的混合检索方法。
# 支持结果重排序和遗忘曲线加权。
# """
from typing import List, Dict, Any, Optional
import math
from datetime import datetime
from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.variate_config import ForgettingEngineConfig
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
# from typing import List, Dict, Any, Optional
# import math
# from datetime import datetime
# from app.core.logging_config import get_memory_logger
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
# from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
# from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
# from app.core.memory.models.variate_config import ForgettingEngineConfig
# from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
logger = get_memory_logger(__name__)
# logger = get_memory_logger(__name__)
class HybridSearchStrategy(SearchStrategy):
"""混合搜索策略
# class HybridSearchStrategy(SearchStrategy):
# """混合搜索策略
结合关键词搜索和语义搜索的优势:
- 关键词搜索:精确匹配,适合已知术语
- 语义搜索:语义理解,适合概念查询
- 混合重排序:综合两种搜索的结果
- 遗忘曲线:根据时间衰减调整相关性
"""
# 结合关键词搜索和语义搜索的优势:
# - 关键词搜索:精确匹配,适合已知术语
# - 语义搜索:语义理解,适合概念查询
# - 混合重排序:综合两种搜索的结果
# - 遗忘曲线:根据时间衰减调整相关性
# """
def __init__(
self,
connector: Optional[Neo4jConnector] = None,
embedder_client: Optional[OpenAIEmbedderClient] = None,
alpha: float = 0.6,
use_forgetting_curve: bool = False,
forgetting_config: Optional[ForgettingEngineConfig] = None
):
"""初始化混合搜索策略
# def __init__(
# self,
# connector: Optional[Neo4jConnector] = None,
# embedder_client: Optional[OpenAIEmbedderClient] = None,
# alpha: float = 0.6,
# use_forgetting_curve: bool = False,
# forgetting_config: Optional[ForgettingEngineConfig] = None
# ):
# """初始化混合搜索策略
Args:
connector: Neo4j连接器
embedder_client: 嵌入模型客户端
alpha: BM25分数权重0.0-1.01-alpha为嵌入分数权重
use_forgetting_curve: 是否使用遗忘曲线
forgetting_config: 遗忘引擎配置
"""
self.connector = connector
self.embedder_client = embedder_client
self.alpha = alpha
self.use_forgetting_curve = use_forgetting_curve
self.forgetting_config = forgetting_config or ForgettingEngineConfig()
self._owns_connector = connector is None
# Args:
# connector: Neo4j连接器
# embedder_client: 嵌入模型客户端
# alpha: BM25分数权重0.0-1.01-alpha为嵌入分数权重
# use_forgetting_curve: 是否使用遗忘曲线
# forgetting_config: 遗忘引擎配置
# """
# self.connector = connector
# self.embedder_client = embedder_client
# self.alpha = alpha
# self.use_forgetting_curve = use_forgetting_curve
# self.forgetting_config = forgetting_config or ForgettingEngineConfig()
# self._owns_connector = connector is None
# 创建子策略
self.keyword_strategy = KeywordSearchStrategy(connector=connector)
self.semantic_strategy = SemanticSearchStrategy(
connector=connector,
embedder_client=embedder_client
)
# # 创建子策略
# self.keyword_strategy = KeywordSearchStrategy(connector=connector)
# self.semantic_strategy = SemanticSearchStrategy(
# connector=connector,
# embedder_client=embedder_client
# )
async def __aenter__(self):
    """Enter the async context.

    When this strategy owns its connector, a new ``Neo4jConnector`` is
    created and shared with the keyword and semantic sub-strategies.
    Returns the strategy itself.
    """
    if not self._owns_connector:
        return self

    self.connector = Neo4jConnector()
    for sub_strategy in (self.keyword_strategy, self.semantic_strategy):
        sub_strategy.connector = self.connector
    return self
# async def __aenter__(self):
# """异步上下文管理器入口"""
# if self._owns_connector:
# self.connector = Neo4jConnector()
# self.keyword_strategy.connector = self.connector
# self.semantic_strategy.connector = self.connector
# return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave the async context, closing the connector if this strategy owns one."""
    if not (self._owns_connector and self.connector):
        return
    await self.connector.close()
# async def __aexit__(self, exc_type, exc_val, exc_tb):
# """异步上下文管理器出口"""
# if self._owns_connector and self.connector:
# await self.connector.close()
async def search(
self,
query_text: str,
group_id: Optional[str] = None,
limit: int = 50,
include: Optional[List[str]] = None,
**kwargs
) -> SearchResult:
"""执行混合搜索
# async def search(
# self,
# query_text: str,
# group_id: Optional[str] = None,
# limit: int = 50,
# include: Optional[List[str]] = None,
# **kwargs
# ) -> SearchResult:
# """执行混合搜索
Args:
query_text: 查询文本
group_id: 可选的组ID过滤
limit: 每个类别的最大结果数
include: 要包含的搜索类别列表
**kwargs: 其他搜索参数如alpha, use_forgetting_curve
# Args:
# query_text: 查询文本
# group_id: 可选的组ID过滤
# limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve
Returns:
SearchResult: 搜索结果对象
"""
logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# Returns:
# SearchResult: 搜索结果对象
# """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
# 从kwargs中获取参数
alpha = kwargs.get("alpha", self.alpha)
use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
# # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha)
# use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
# 获取有效的搜索类别
include_list = self._get_include_list(include)
# # 获取有效的搜索类别
# include_list = self._get_include_list(include)
try:
# 并行执行关键词搜索和语义搜索
keyword_result = await self.keyword_strategy.search(
query_text=query_text,
group_id=group_id,
limit=limit,
include=include_list
)
# try:
# # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search(
# query_text=query_text,
# group_id=group_id,
# limit=limit,
# include=include_list
# )
semantic_result = await self.semantic_strategy.search(
query_text=query_text,
group_id=group_id,
limit=limit,
include=include_list
)
# semantic_result = await self.semantic_strategy.search(
# query_text=query_text,
# group_id=group_id,
# limit=limit,
# include=include_list
# )
# 重排序结果
if use_forgetting:
reranked_results = self._rerank_with_forgetting_curve(
keyword_result=keyword_result,
semantic_result=semantic_result,
alpha=alpha,
limit=limit
)
else:
reranked_results = self._rerank_hybrid_results(
keyword_result=keyword_result,
semantic_result=semantic_result,
alpha=alpha,
limit=limit
)
# # 重排序结果
# if use_forgetting:
# reranked_results = self._rerank_with_forgetting_curve(
# keyword_result=keyword_result,
# semantic_result=semantic_result,
# alpha=alpha,
# limit=limit
# )
# else:
# reranked_results = self._rerank_hybrid_results(
# keyword_result=keyword_result,
# semantic_result=semantic_result,
# alpha=alpha,
# limit=limit
# )
# 创建元数据
metadata = self._create_metadata(
query_text=query_text,
search_type="hybrid",
group_id=group_id,
limit=limit,
include=include_list,
alpha=alpha,
use_forgetting_curve=use_forgetting
)
# # 创建元数据
# metadata = self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# limit=limit,
# include=include_list,
# alpha=alpha,
# use_forgetting_curve=use_forgetting
# )
# 添加结果统计
metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
metadata["total_keyword_results"] = keyword_result.total_results()
metadata["total_semantic_results"] = semantic_result.total_results()
metadata["total_reranked_results"] = reranked_results.total_results()
# # 添加结果统计
# metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
# metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
# metadata["total_keyword_results"] = keyword_result.total_results()
# metadata["total_semantic_results"] = semantic_result.total_results()
# metadata["total_reranked_results"] = reranked_results.total_results()
reranked_results.metadata = metadata
# reranked_results.metadata = metadata
logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
return reranked_results
# logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
# return reranked_results
except Exception as e:
logger.error(f"混合搜索失败: {e}", exc_info=True)
# 返回空结果但包含错误信息
return SearchResult(
metadata=self._create_metadata(
query_text=query_text,
search_type="hybrid",
group_id=group_id,
limit=limit,
error=str(e)
)
)
# except Exception as e:
# logger.error(f"混合搜索失败: {e}", exc_info=True)
# # 返回空结果但包含错误信息
# return SearchResult(
# metadata=self._create_metadata(
# query_text=query_text,
# search_type="hybrid",
# group_id=group_id,
# limit=limit,
# error=str(e)
# )
# )
def _normalize_scores(
self,
results: List[Dict[str, Any]],
score_field: str = "score"
) -> List[Dict[str, Any]]:
"""使用z-score标准化和sigmoid转换归一化分数
# def _normalize_scores(
# self,
# results: List[Dict[str, Any]],
# score_field: str = "score"
# ) -> List[Dict[str, Any]]:
# """使用z-score标准化和sigmoid转换归一化分数
Args:
results: 结果列表
score_field: 分数字段名
# Args:
# results: 结果列表
# score_field: 分数字段名
Returns:
List[Dict[str, Any]]: 归一化后的结果列表
"""
if not results:
return results
# Returns:
# List[Dict[str, Any]]: 归一化后的结果列表
# """
# if not results:
# return results
# 提取分数
scores = []
for item in results:
if score_field in item:
score = item.get(score_field)
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
scores.append(0.0)
# # 提取分数
# scores = []
# for item in results:
# if score_field in item:
# score = item.get(score_field)
# if score is not None and isinstance(score, (int, float)):
# scores.append(float(score))
# else:
# scores.append(0.0)
if not scores or len(scores) == 1:
# 单个分数或无分数设置为1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
return results
# if not scores or len(scores) == 1:
# # 单个分数或无分数设置为1.0
# for item in results:
# if score_field in item:
# item[f"normalized_{score_field}"] = 1.0
# return results
# 计算均值和标准差
mean_score = sum(scores) / len(scores)
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
std_dev = math.sqrt(variance)
# # 计算均值和标准差
# mean_score = sum(scores) / len(scores)
# variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
# std_dev = math.sqrt(variance)
if std_dev == 0:
# 所有分数相同设置为1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
else:
# z-score标准化 + sigmoid转换
for item in results:
if score_field in item:
score = item[score_field]
if score is None or not isinstance(score, (int, float)):
score = 0.0
z_score = (score - mean_score) / std_dev
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
# if std_dev == 0:
# # 所有分数相同设置为1.0
# for item in results:
# if score_field in item:
# item[f"normalized_{score_field}"] = 1.0
# else:
# # z-score标准化 + sigmoid转换
# for item in results:
# if score_field in item:
# score = item[score_field]
# if score is None or not isinstance(score, (int, float)):
# score = 0.0
# z_score = (score - mean_score) / std_dev
# normalized = 1 / (1 + math.exp(-z_score))
# item[f"normalized_{score_field}"] = normalized
return results
# return results
def _rerank_hybrid_results(
self,
keyword_result: SearchResult,
semantic_result: SearchResult,
alpha: float,
limit: int
) -> SearchResult:
"""重排序混合搜索结果
# def _rerank_hybrid_results(
# self,
# keyword_result: SearchResult,
# semantic_result: SearchResult,
# alpha: float,
# limit: int
# ) -> SearchResult:
# """重排序混合搜索结果
Args:
keyword_result: 关键词搜索结果
semantic_result: 语义搜索结果
alpha: BM25分数权重
limit: 结果限制
# Args:
# keyword_result: 关键词搜索结果
# semantic_result: 语义搜索结果
# alpha: BM25分数权重
# limit: 结果限制
Returns:
SearchResult: 重排序后的结果
"""
reranked_data = {}
# Returns:
# SearchResult: 重排序后的结果
# """
# reranked_data = {}
for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = getattr(keyword_result, category, [])
semantic_items = getattr(semantic_result, category, [])
# for category in ["statements", "chunks", "entities", "summaries"]:
# keyword_items = getattr(keyword_result, category, [])
# semantic_items = getattr(semantic_result, category, [])
# 归一化分数
keyword_items = self._normalize_scores(keyword_items, "score")
semantic_items = self._normalize_scores(semantic_items, "score")
# # 归一化分数
# keyword_items = self._normalize_scores(keyword_items, "score")
# semantic_items = self._normalize_scores(semantic_items, "score")
# 合并结果
combined_items = {}
# # 合并结果
# combined_items = {}
# 添加关键词结果
for item in keyword_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0
# # 添加关键词结果
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0
# 添加或更新语义结果
for item in semantic_items:
item_id = item.get("id") or item.get("uuid")
if item_id:
if item_id in combined_items:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# # 添加或更新语义结果
# for item in semantic_items:
# item_id = item.get("id") or item.get("uuid")
# if item_id:
# if item_id in combined_items:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# 计算组合分数
for item_id, item in combined_items.items():
bm25_score = item.get("bm25_score", 0)
embedding_score = item.get("embedding_score", 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
item["combined_score"] = combined_score
# # 计算组合分数
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
# 排序并限制结果
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
# # 排序并限制结果
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
reranked_data[category] = sorted_items
# reranked_data[category] = sorted_items
return SearchResult(
statements=reranked_data.get("statements", []),
chunks=reranked_data.get("chunks", []),
entities=reranked_data.get("entities", []),
summaries=reranked_data.get("summaries", [])
)
# return SearchResult(
# statements=reranked_data.get("statements", []),
# chunks=reranked_data.get("chunks", []),
# entities=reranked_data.get("entities", []),
# summaries=reranked_data.get("summaries", [])
# )
def _parse_datetime(self, value: Any) -> Optional[datetime]:
"""解析日期时间字符串"""
if value is None:
return None
if isinstance(value, datetime):
return value
if isinstance(value, str):
s = value.strip()
if not s:
return None
try:
return datetime.fromisoformat(s)
except Exception:
return None
return None
# def _parse_datetime(self, value: Any) -> Optional[datetime]:
# """解析日期时间字符串"""
# if value is None:
# return None
# if isinstance(value, datetime):
# return value
# if isinstance(value, str):
# s = value.strip()
# if not s:
# return None
# try:
# return datetime.fromisoformat(s)
# except Exception:
# return None
# return None
def _rerank_with_forgetting_curve(
self,
keyword_result: SearchResult,
semantic_result: SearchResult,
alpha: float,
limit: int
) -> SearchResult:
"""使用遗忘曲线重排序混合搜索结果
# def _rerank_with_forgetting_curve(
# self,
# keyword_result: SearchResult,
# semantic_result: SearchResult,
# alpha: float,
# limit: int
# ) -> SearchResult:
# """使用遗忘曲线重排序混合搜索结果
Args:
keyword_result: 关键词搜索结果
semantic_result: 语义搜索结果
alpha: BM25分数权重
limit: 结果限制
# Args:
# keyword_result: 关键词搜索结果
# semantic_result: 语义搜索结果
# alpha: BM25分数权重
# limit: 结果限制
Returns:
SearchResult: 重排序后的结果
"""
engine = ForgettingEngine(self.forgetting_config)
now_dt = datetime.now()
# Returns:
# SearchResult: 重排序后的结果
# """
# engine = ForgettingEngine(self.forgetting_config)
# now_dt = datetime.now()
reranked_data = {}
# reranked_data = {}
for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = getattr(keyword_result, category, [])
semantic_items = getattr(semantic_result, category, [])
# for category in ["statements", "chunks", "entities", "summaries"]:
# keyword_items = getattr(keyword_result, category, [])
# semantic_items = getattr(semantic_result, category, [])
# 归一化分数
keyword_items = self._normalize_scores(keyword_items, "score")
semantic_items = self._normalize_scores(semantic_items, "score")
# # 归一化分数
# keyword_items = self._normalize_scores(keyword_items, "score")
# semantic_items = self._normalize_scores(semantic_items, "score")
# 合并结果
combined_items = {}
# # 合并结果
# combined_items = {}
for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
for item in src_items:
item_id = item.get("id") or item.get("uuid")
if not item_id:
continue
# for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
# for item in src_items:
# item_id = item.get("id") or item.get("uuid")
# if not item_id:
# continue
if item_id not in combined_items:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = 0
# if item_id not in combined_items:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
if is_embedding:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# 计算分数并应用遗忘权重
for item_id, item in combined_items.items():
bm25_score = float(item.get("bm25_score", 0) or 0)
embedding_score = float(item.get("embedding_score", 0) or 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# # 计算分数并应用遗忘权重
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# 计算时间衰减
dt = self._parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# # 计算时间衰减
# dt = self._parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
memory_strength = 1.0 # 默认强度
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# memory_strength = 1.0 # 默认强度
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days,
# memory_strength=memory_strength
# )
final_score = combined_score * forgetting_weight
item["combined_score"] = final_score
item["forgetting_weight"] = forgetting_weight
item["time_elapsed_days"] = time_elapsed_days
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
# item["forgetting_weight"] = forgetting_weight
# item["time_elapsed_days"] = time_elapsed_days
# 排序并限制结果
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
# # 排序并限制结果
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
reranked_data[category] = sorted_items
# reranked_data[category] = sorted_items
return SearchResult(
statements=reranked_data.get("statements", []),
chunks=reranked_data.get("chunks", []),
entities=reranked_data.get("entities", []),
summaries=reranked_data.get("summaries", [])
)
# return SearchResult(
# statements=reranked_data.get("statements", []),
# chunks=reranked_data.get("chunks", []),
# entities=reranked_data.get("entities", []),
# summaries=reranked_data.get("summaries", [])
# )

View File

@@ -5,15 +5,20 @@
使用余弦相似度进行语义匹配。
"""
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_memory_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.storage_services.search.search_strategy import (
SearchResult,
SearchStrategy,
)
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
logger = get_memory_logger(__name__)
@@ -62,7 +67,9 @@ class SemanticSearchStrategy(SearchStrategy):
"""
try:
# 从数据库读取嵌入器配置
embedder_config_dict = get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
rb_config = RedBearModelConfig(
model_name=embedder_config_dict["model_name"],
provider=embedder_config_dict["provider"],

View File

@@ -1,445 +0,0 @@
# Memory 模块工具函数文档
本目录包含 Memory 模块使用的所有工具函数,统一管理以提高代码可维护性和可复用性。
## 目录结构
```
app/core/memory/utils/
├── __init__.py # 包初始化文件,导出所有公共接口
├── README.md # 本文档
├── config/ # 配置管理模块
│ ├── __init__.py # 配置模块初始化
│ ├── config_utils.py # 配置管理工具
│ ├── definitions.py # 全局定义和常量
│ ├── overrides.py # 运行时配置覆写
│ ├── get_data.py # 数据获取工具
│ ├── litellm_config.py # LiteLLM 配置和监控
│ └── config_optimization.py # 配置优化工具
├── log/ # 日志管理模块
│ ├── __init__.py # 日志模块初始化
│ ├── logging_utils.py # 日志工具
│ └── audit_logger.py # 审计日志
├── prompt/ # 提示词管理模块
│ ├── __init__.py # 提示词模块初始化
│ ├── prompt_utils.py # 提示词渲染工具
│ ├── template_render.py # 模板渲染工具
│ └── prompts/ # Jinja2 提示词模板目录
│ ├── entity_dedup.jinja2 # 实体去重提示词
│ ├── extract_statement.jinja2 # 陈述句提取提示词
│ ├── extract_temporal.jinja2 # 时间信息提取提示词
│ ├── extract_triplet.jinja2 # 三元组提取提示词
│ ├── memory_summary.jinja2 # 记忆摘要提示词
│ ├── evaluate.jinja2 # 评估提示词
│ ├── reflexion.jinja2 # 反思提示词
│ ├── system.jinja2 # 系统提示词
│ └── user.jinja2 # 用户提示词
├── llm/ # LLM 工具模块
│ ├── __init__.py # LLM 模块初始化
│ └── llm_utils.py # LLM 客户端工具
├── data/ # 数据处理模块
│ ├── __init__.py # 数据模块初始化
│ ├── text_utils.py # 文本处理工具
│ ├── time_utils.py # 时间处理工具
│ └── ontology.py # 本体定义(谓语、标签等)
├── paths/ # 路径管理模块
│ ├── __init__.py # 路径模块初始化
│ └── output_paths.py # 输出路径管理
├── visualization/ # 可视化模块
│ ├── __init__.py # 可视化模块初始化
│ └── forgetting_visualizer.py # 遗忘曲线可视化
└── self_reflexion_utils/ # 自我反思工具模块
├── __init__.py # 反思模块初始化
├── evaluate.py # 冲突评估
├── reflexion.py # 反思处理
└── self_reflexion.py # 自我反思主逻辑
```
## 模块分类
### 1. 配置管理(config/)
配置管理模块包含所有与配置相关的工具函数和定义。
#### config_utils.py
提供配置加载和管理功能:
- `get_model_config(model_id)` - 获取 LLM 模型配置
- `get_embedder_config(embedding_id)` - 获取嵌入模型配置
- `get_neo4j_config()` - 获取 Neo4j 数据库配置
- `get_chunker_config(chunker_strategy)` - 获取分块策略配置
- `get_pipeline_config()` - 获取流水线配置
- `get_pruning_config()` - 获取语义剪枝配置
- `get_picture_config()` - 获取图片模型配置
- `get_voice_config()` - 获取语音模型配置
#### definitions.py
全局定义和常量:
- `CONFIG` - 基础配置(从 config.json 加载)
- `RUNTIME_CONFIG` - 运行时配置(从 runtime.json 或数据库加载)
- `PROJECT_ROOT` - 项目根目录路径
- 各种选择配置常量(LLM、嵌入模型、分块策略等)
- `reload_configuration_from_database(config_id)` - 动态重新加载配置
#### overrides.py
运行时配置覆写:
- `load_unified_config(project_root)` - 加载统一配置
#### get_data.py
数据获取工具:
- `get_data(host_id)` - 从 SQL 数据库获取数据
#### litellm_config.py
LiteLLM 配置和监控:
- `LiteLLMConfig` - LiteLLM 配置类
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
- `get_usage_summary()` - 获取使用统计摘要
- `print_usage_summary()` - 打印使用统计
- `get_instant_qps(module)` - 获取即时 QPS 数据
- `print_instant_qps(module)` - 打印即时 QPS 信息
#### config_optimization.py
配置优化工具:
- 配置参数优化相关功能
### 3. LLM 工具(llm/)
LLM 工具模块包含所有与 LLM 客户端相关的工具函数。
#### llm_utils.py
LLM 客户端工具:
- `get_llm_client(llm_id)` - 获取 LLM 客户端实例
- `get_reranker_client(rerank_id)` - 获取重排序客户端实例
- `handle_response(response)` - 处理 LLM 响应
#### litellm_config.py
LiteLLM 配置和监控:
- `LiteLLMConfig` - LiteLLM 配置类
- `setup_litellm_enhanced(max_retries)` - 设置增强的 LiteLLM 配置
- `get_usage_summary()` - 获取使用统计摘要
- `print_usage_summary()` - 打印使用统计
- `get_instant_qps(module)` - 获取即时 QPS 数据
- `print_instant_qps(module)` - 打印即时 QPS 信息
### 4. 提示词管理(prompt/)
提示词管理模块包含所有提示词渲染和模板管理相关的工具函数。
#### prompt_utils.py
提示词渲染工具(使用 Jinja2 模板):
- `get_prompts(message)` - 获取系统和用户提示词
- `render_statement_extraction_prompt(...)` - 渲染陈述句提取提示词
- `render_temporal_extraction_prompt(...)` - 渲染时间信息提取提示词
- `render_entity_dedup_prompt(...)` - 渲染实体去重提示词
- `render_triplet_extraction_prompt(...)` - 渲染三元组提取提示词
- `render_memory_summary_prompt(...)` - 渲染记忆摘要提示词
- `prompt_env` - Jinja2 环境对象
#### template_render.py
模板渲染工具(用于评估和反思):
- `render_evaluate_prompt(evaluate_data, schema)` - 渲染评估提示词
- `render_reflexion_prompt(data, schema)` - 渲染反思提示词
#### prompts/
Jinja2 模板文件目录,包含所有提示词模板
### 5. 数据处理(data/)
数据处理模块包含所有数据处理相关的工具函数。
#### text_utils.py
文本处理工具:
- `escape_lucene_query(query)` - 转义 Lucene 查询特殊字符
- `extract_plain_query(query_input)` - 从各种输入格式提取纯文本查询
#### time_utils.py
时间处理工具:
- `validate_date_format(date_str)` - 验证日期格式(YYYY-MM-DD)
- `normalize_date(date_str)` - 标准化日期格式
- `normalize_date_safe(date_str, default)` - 安全的日期标准化(带默认值)
- `preprocess_date_string(date_str)` - 预处理日期字符串
#### ontology.py
本体定义:
- `PREDICATE_DEFINITIONS` - 谓语定义字典
- `LABEL_DEFINITIONS` - 标签定义字典
- `Predicate` - 谓语枚举
- `StatementType` - 陈述句类型枚举
- `TemporalInfo` - 时间信息枚举
- `RelevenceInfo` - 相关性信息枚举
### 2. 日志管理(log/)
日志管理模块包含所有与日志记录相关的工具函数。
#### logging_utils.py
日志工具:
- `log_prompt_rendering(role, content)` - 记录提示词渲染
- `log_template_rendering(template_name, context)` - 记录模板渲染
- `log_time(operation, duration)` - 记录操作耗时
- `prompt_logger` - 提示词日志记录器
#### audit_logger.py
审计日志:
- `audit_logger` - 审计日志记录器
- 记录系统关键操作和安全事件
### 6. 自我反思工具(self_reflexion_utils/)
自我反思工具模块包含记忆冲突检测和反思处理功能。
#### evaluate.py
冲突评估:
- `conflict(evaluate_data, schema)` - 评估记忆冲突
#### reflexion.py
反思处理:
- `reflexion(data, schema)` - 执行反思处理
#### self_reflexion.py
自我反思主逻辑:
- `self_reflexion(...)` - 自我反思主函数
### 7. 数据模型
#### json_schema.py
JSON Schema 数据模型:
- `BaseDataSchema` - 基础数据模型
- `ConflictResultSchema` - 冲突结果模型
- `ConflictSchema` - 冲突模型
- `ReflexionSchema` - 反思模型
- `ResolvedSchema` - 解决方案模型
- `ReflexionResultSchema` - 反思结果模型
#### messages.py
API 消息模型:
- `ConfigKey` - 配置键模型
- `ChunkerStrategy` - 分块策略枚举
- `ConfigParams` - 配置参数模型
- `ConfigParamsCreate` - 创建配置参数模型
- `ConfigUpdate` - 更新配置模型
- `ConfigUpdateExtracted` - 更新萃取引擎配置模型
- `ConfigUpdateForget` - 更新遗忘引擎配置模型
- `ConfigPilotRun` - 试运行配置模型
- `ConfigFilter` - 配置过滤模型
- `ApiResponse` - API 响应模型
- `ok(msg, data)` - 成功响应构造函数
- `fail(msg, error_code, data)` - 失败响应构造函数
### 8. 可视化(visualization/)
可视化模块包含所有可视化相关的工具函数。
#### forgetting_visualizer.py
遗忘曲线可视化:
- `export_memory_curve_numpy(...)` - 导出记忆曲线为 NumPy 数组
- `export_memory_curves_multiple_strengths(...)` - 导出多个强度的记忆曲线
- `export_parameter_sweep_numpy(...)` - 导出参数扫描结果
- `visualize_forgetting_curve(...)` - 可视化遗忘曲线
- `plot_3d_forgetting_surface(...)` - 绘制 3D 遗忘曲线表面
- `create_comparison_visualization(...)` - 创建对比可视化
- `save_memory_curves_to_file(...)` - 保存记忆曲线到文件
### 9. 路径管理(paths/)
路径管理模块包含所有路径管理相关的工具函数。
#### output_paths.py
输出路径管理:
- `get_output_dir()` - 获取输出目录
- `get_output_path(filename)` - 获取输出文件路径
## 使用示例
### 配置管理
```python
from app.core.memory.utils.config import get_model_config, get_pipeline_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID
# 获取模型配置
model_config = get_model_config("model_id_123")
# 获取流水线配置
pipeline_config = get_pipeline_config()
# 使用全局常量
llm_id = SELECTED_LLM_ID
```
### 日志管理
```python
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger
# 记录提示词渲染
log_prompt_rendering('user', 'Hello, world!')
# 记录操作耗时
log_time('extraction', 1.23)
# 使用审计日志
audit_logger.info('User action performed')
```
### LLM 工具
```python
from app.core.memory.utils.llm import get_llm_client
# 获取 LLM 客户端
llm_client = get_llm_client("llm_id_456")
# 调用 LLM
response = await llm_client.chat([
{"role": "user", "content": "Hello"}
])
```
### 提示词渲染
```python
from app.core.memory.utils.prompt import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS
# 渲染陈述句提取提示词
prompt = await render_statement_extraction_prompt(
chunk_content="对话内容...",
definitions=LABEL_DEFINITIONS,
json_schema=schema,
granularity=2
)
```
### 数据处理
```python
from app.core.memory.utils.data.time_utils import normalize_date
from app.core.memory.utils.data.text_utils import escape_lucene_query
# 标准化日期
normalized = normalize_date("2025/10/28") # 返回 "2025-10-28"
# 转义 Lucene 查询
escaped = escape_lucene_query("user:admin AND status:active")
```
### 运行时配置覆写
```python
from app.core.memory.utils import apply_runtime_overrides_with_config_id
# 使用指定 config_id 覆写配置
runtime_cfg = {"selections": {}}
updated_cfg = apply_runtime_overrides_with_config_id(
project_root="/path/to/project",
runtime_cfg=runtime_cfg,
config_id="config_123"
)
```
## 迁移说明
### 从旧路径迁移
如果你的代码使用了旧的导入路径,请按以下方式更新:
**旧路径(2024年11月之前)**
```python
from app.core.memory.src.utils.config_utils import get_model_config
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
from app.core.memory.src.data_config_api.utils.messages import ok, fail
```
**中间路径(2024年11月)**
```python
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.schemas.memory_storage_schema import ok, fail
```
**新路径(2024年11月27日之后)**
```python
# 配置相关
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import get_model_config # 简化导入
# 日志相关
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.log import log_prompt_rendering # 简化导入
# 其他工具
from app.core.memory.utils import prompt_utils
from app.schemas.memory_storage_schema import ok, fail
```
### 目录结构重组(2024年11月27日)
utils 目录已按功能进行了完整的重组:
**重组前的结构:**
- 所有文件都在 `app/core/memory/utils/` 根目录下
**重组后的结构:**
- `config/` - 配置管理相关文件
- `log/` - 日志管理相关文件
- `prompt/` - 提示词管理相关文件
- `llm/` - LLM 工具相关文件
- `data/` - 数据处理相关文件
- `paths/` - 路径管理相关文件
- `visualization/` - 可视化相关文件
- `self_reflexion_utils/` - 自我反思工具(已存在)
**导入路径变化:**
```python
# 旧导入方式
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt
# 新导入方式
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
# 或使用简化导入
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.log import log_prompt_rendering
from app.core.memory.utils.prompt import render_statement_extraction_prompt
```
## 维护指南
### 添加新工具函数
1. 在相应的模块文件中添加函数
2. 在 `__init__.py` 中导出函数
3. 在本 README 中添加文档
4. 编写单元测试
### 删除旧工具函数
1. 确认没有代码使用该函数
2. 从模块文件中删除函数
3. 从 `__init__.py` 中删除导出
4. 更新本 README
### 重构工具函数
1. 保持向后兼容性(使用别名或包装器)
2. 更新所有使用该函数的代码
3. 更新文档和测试
4. 在适当时机删除旧版本
## 注意事项
1. **向后兼容性**:所有工具函数应保持向后兼容,避免破坏现有代码
2. **文档完整性**:每个函数都应有清晰的文档字符串
3. **类型注解**:使用类型注解提高代码可读性
4. **错误处理**:工具函数应有适当的错误处理
5. **测试覆盖**:所有工具函数都应有单元测试
## 相关文档
- [Memory 模块架构设计](../.kiro/specs/memory-refactoring/design.md)
- [Memory 模块需求文档](../.kiro/specs/memory-refactoring/requirements.md)
- [Memory 模块任务列表](../.kiro/specs/memory-refactoring/tasks.md)

View File

@@ -6,33 +6,26 @@
# 从子模块导出常用函数和常量,保持向后兼容
from .config_utils import (
get_model_config,
get_embedder_config,
get_neo4j_config,
get_chunker_config,
get_embedder_config,
get_model_config,
get_picture_config,
get_pipeline_config,
get_pruning_config,
get_picture_config,
get_voice_config,
)
from .definitions import (
CONFIG,
RUNTIME_CONFIG,
PROJECT_ROOT,
SELECTED_LLM_ID,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_RERANK_ID,
SELECTED_LLM_PICTURE_NAME,
SELECTED_LLM_VOICE_NAME,
REFLEXION_ENABLED,
REFLEXION_ITERATION_PERIOD,
REFLEXION_RANGE,
REFLEXION_BASELINE,
reload_configuration_from_database,
)
from .overrides import load_unified_config
# DEPRECATED: Global configuration variables removed
# Use MemoryConfig objects with dependency injection instead
# from .definitions import (
# CONFIG, # DEPRECATED - empty dict for backward compatibility
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
# PROJECT_ROOT, # Still needed for file paths
# reload_configuration_from_database, # DEPRECATED - returns False
# )
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
from .get_data import get_data
# litellm_config 需要时动态导入,避免循环依赖
# from .litellm_config import (
# LiteLLMConfig,
@@ -47,29 +40,16 @@ __all__ = [
# config_utils
"get_model_config",
"get_embedder_config",
"get_neo4j_config",
"get_chunker_config",
"get_pipeline_config",
"get_pruning_config",
"get_picture_config",
"get_voice_config",
# definitions
"CONFIG",
"RUNTIME_CONFIG",
"PROJECT_ROOT",
"SELECTED_LLM_ID",
"SELECTED_EMBEDDING_ID",
"SELECTED_GROUP_ID",
"SELECTED_RERANK_ID",
"SELECTED_LLM_PICTURE_NAME",
"SELECTED_LLM_VOICE_NAME",
"REFLEXION_ENABLED",
"REFLEXION_ITERATION_PERIOD",
"REFLEXION_RANGE",
"REFLEXION_BASELINE",
"reload_configuration_from_database",
# overrides
"load_unified_config",
# definitions (DEPRECATED - use MemoryConfig objects instead)
# "CONFIG", # DEPRECATED
# "RUNTIME_CONFIG", # DEPRECATED
# "PROJECT_ROOT",
# "reload_configuration_from_database", # DEPRECATED
# get_data
"get_data",
# litellm_config - 需要时从 .litellm_config 直接导入

View File

@@ -1,94 +1,74 @@
import uuid
import json
from typing import Optional
"""
Configuration utilities - Backward compatibility layer
from sqlalchemy.orm import Session
from fastapi.exceptions import HTTPException
from fastapi import status
DEPRECATED: These functions now require a db session parameter.
New code should use MemoryConfigService(db) instance directly.
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
DedupConfig,
StatementExtractionConfig,
ForgettingEngineConfig,
)
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db
from app.models.models_model import ModelConfig, ModelApiKey
from app.services.model_service import ModelConfigService
def get_model_config(model_id: str, db: Session | None = None) -> dict:
For functions that don't require db (get_pipeline_config, get_pruning_config),
they are still re-exported here.
"""
import warnings
from app.services.memory_config_service import MemoryConfigService
# These functions don't require db - safe to re-export as static methods
get_pipeline_config = MemoryConfigService.get_pipeline_config
get_pruning_config = MemoryConfigService.get_pruning_config
def get_model_config(model_id: str, db=None):
"""DEPRECATED: Use MemoryConfigService(db).get_model_config(model_id) directly."""
if db is None:
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen) # 取到真正的 Session
raise ValueError(
"get_model_config now requires a db session. "
"Use MemoryConfigService(db).get_model_config(model_id) directly."
)
return MemoryConfigService(db).get_model_config(model_id)
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
print(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
# 从环境变量读取超时和重试配置
from app.core.config import settings
model_config = {
"model_name": apiConfig.model_name,
"provider": apiConfig.provider,
"api_key": apiConfig.api_key,
"base_url": apiConfig.api_base,
"model_config_id":apiConfig.model_config_id,
"type": config.type,
# 添加超时和重试配置,避免 LLM 请求超时
"timeout": settings.LLM_TIMEOUT, # 从环境变量读取默认120秒
"max_retries": settings.LLM_MAX_RETRIES, # 从环境变量读取默认2次
}
# 写入model_config.log文件中
with open("logs/model_config.log", "a", encoding="utf-8") as f:
f.write(f"模型ID: {model_id}\n")
f.write(f"模型配置信息:\n{model_config}\n")
f.write("=============================\n\n")
return model_config
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
def get_embedder_config(embedding_id: str, db=None):
"""DEPRECATED: Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."""
if db is None:
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen) # 取到真正的 Session
raise ValueError(
"get_embedder_config now requires a db session. "
"Use MemoryConfigService(db).get_embedder_config(embedding_id) directly."
)
return MemoryConfigService(db).get_embedder_config(embedding_id)
config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
if not config:
print(f"嵌入模型ID {embedding_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model_config = {
"model_name": apiConfig.model_name,
"provider": apiConfig.provider,
"api_key": apiConfig.api_key,
"base_url": apiConfig.api_base,
"model_config_id":apiConfig.model_config_id,
# Ensure required field for RedBearModelConfig validation
"type": config.type,
# 添加超时和重试配置,避免嵌入服务请求超时
"timeout": 120.0, # 嵌入服务超时时间(秒)
"max_retries": 5, # 最大重试次数
}
# 写入embedder_config.log文件中
with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
f.write(f"嵌入模型ID: {embedding_id}\n")
f.write(f"嵌入模型配置信息:\n{model_config}\n")
f.write("=============================\n\n")
return model_config
def get_neo4j_config() -> dict:
"""Retrieves the Neo4j configuration from the config file."""
return CONFIG.get("neo4j", {})
def get_picture_config(llm_name: str) -> dict:
"""Retrieves the configuration for a specific model from the config file."""
"""Retrieves the configuration for a specific model from the config file.
.. deprecated::
This function is deprecated and will be removed in a future version.
Use database-backed model configuration instead.
"""
warnings.warn(
"get_picture_config is deprecated and will be removed in a future version. "
"Use database-backed model configuration instead.",
DeprecationWarning,
stacklevel=2
)
for model_config in CONFIG.get("picture_recognition", []):
if model_config["llm_name"] == llm_name:
return model_config
raise ValueError(f"Model '{llm_name}' not found in config.json")
def get_voice_config(llm_name: str) -> dict:
"""Retrieves the configuration for a specific model from the config file."""
"""Retrieves the configuration for a specific model from the config file.
.. deprecated::
This function is deprecated and will be removed in a future version.
Use database-backed model configuration instead.
"""
warnings.warn(
"get_voice_config is deprecated and will be removed in a future version. "
"Use database-backed model configuration instead.",
DeprecationWarning,
stacklevel=2
)
for model_config in CONFIG.get("voice_recognition", []):
if model_config["llm_name"] == llm_name:
return model_config
@@ -96,20 +76,15 @@ def get_voice_config(llm_name: str) -> dict:
def get_chunker_config(chunker_strategy: str) -> dict:
"""Retrieves the configuration for a specific chunker strategy.
"""Retrieves the configuration for a specific chunker strategy."""
Enhancements:
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
- Falls back to the first available chunker config when the requested one is missing.
"""
# 1) Try to find exact match in config
chunker_list = CONFIG.get("chunker_list", [])
for chunker_config in chunker_list:
if chunker_config.get("chunker_strategy") == chunker_strategy:
return chunker_config
# 2) Provide sane defaults for newer strategies
default_configs = {
"RecursiveChunker": {
"chunker_strategy": "RecursiveChunker",
"embedding_model": "BAAI/bge-m3",
"chunk_size": 512,
"min_characters_per_chunk": 50
},
"LLMChunker": {
"chunker_strategy": "LLMChunker",
"embedding_model": "BAAI/bge-m3",
@@ -134,134 +109,6 @@ def get_chunker_config(chunker_strategy: str) -> dict:
if chunker_strategy in default_configs:
return default_configs[chunker_strategy]
# 3) Fallback: use first available config but tag with requested strategy
if chunker_list:
fallback = chunker_list[0].copy()
fallback["chunker_strategy"] = chunker_strategy
# Non-fatal notice for visibility in logs if any
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
return fallback
# 4) If no configs available at all
raise ValueError(
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
f"Chunker '{chunker_strategy}' not found "
)
def get_pipeline_config() -> ExtractionPipelineConfig:
    """Assemble an ``ExtractionPipelineConfig`` from ``RUNTIME_CONFIG``.

    Three sections of ``RUNTIME_CONFIG`` are consulted, each treated as
    empty when missing or falsy:

    * ``deduplication``        -> ``DedupConfig``
    * ``statement_extraction`` -> ``StatementExtractionConfig``
    * ``forgetting_engine``    -> ``ForgettingEngineConfig``

    Only keys that are actually present in a section are forwarded to the
    corresponding config class; everything else is left to that class's own
    defaults. The one exception is ``enable_llm_dedup_blockwise``, which is
    always passed: it comes from the ``deduplication`` section when present,
    otherwise from the legacy top-level ``enable_llm_dedup`` key, otherwise
    it is ``False``. Environment variables are not read here.

    Returns:
        ExtractionPipelineConfig: The combined statement-extraction,
        deduplication and forgetting-engine configuration.
    """
    dedup_section = RUNTIME_CONFIG.get("deduplication", {}) or {}
    statement_section = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
    forgetting_section = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}

    def _present(section, names):
        # Copy through only the keys the section really defines.
        return {name: section[name] for name in names if name in section}

    # --- deduplication -------------------------------------------------
    # Block-wise LLM dedup switch: new key first, then the legacy top-level
    # key, finally a hard default of False.
    if "enable_llm_dedup_blockwise" in dedup_section:
        blockwise = bool(dedup_section.get("enable_llm_dedup_blockwise"))
    else:
        legacy_flag = RUNTIME_CONFIG.get("enable_llm_dedup")
        blockwise = False if legacy_flag is None else bool(legacy_flag)
    dedup_kwargs = {"enable_llm_dedup_blockwise": blockwise}

    # Boolean switches that are only forwarded when explicitly configured.
    for flag in (
        "enable_llm_disambiguation",
        "enable_llm_fallback_only_on_borderline",
    ):
        if flag in dedup_section:
            dedup_kwargs[flag] = bool(dedup_section.get(flag))

    # Fuzzy-match thresholds.
    dedup_kwargs.update(_present(dedup_section, (
        "fuzzy_name_threshold_strict",
        "fuzzy_type_threshold_strict",
        "fuzzy_overall_threshold",
        "fuzzy_unknown_type_name_threshold",
        "fuzzy_unknown_type_type_threshold",
    )))
    # Weights and bonuses used for the overall score.
    dedup_kwargs.update(_present(dedup_section, (
        "name_weight",
        "desc_weight",
        "type_weight",
        "context_bonus",
        "llm_fallback_floor",
        "llm_fallback_ceiling",
    )))
    # Parameters of the iterative LLM dedup pass.
    dedup_kwargs.update(_present(dedup_section, (
        "llm_block_size",
        "llm_block_concurrency",
        "llm_pair_concurrency",
        "llm_max_rounds",
    )))
    dedup_config = DedupConfig(**dedup_kwargs)

    # --- statement extraction -----------------------------------------
    stmt_config = StatementExtractionConfig(**_present(statement_section, (
        "statement_granularity",
        "temperature",
        "include_dialogue_context",
        "max_dialogue_context_chars",
    )))

    # --- forgetting engine --------------------------------------------
    forget_config = ForgettingEngineConfig(**_present(
        forgetting_section, ("offset", "lambda_time", "lambda_mem"),
    ))

    return ExtractionPipelineConfig(
        statement_extraction=stmt_config,
        deduplication=dedup_config,
        forgetting_engine=forget_config,
    )
def get_pruning_config() -> dict:
    """Return the semantic-pruning settings held in ``RUNTIME_CONFIG``.

    The ``"pruning"`` section of ``RUNTIME_CONFIG`` is read (a missing or
    falsy section is treated as empty) and its keys are renamed:

    * ``enabled``   -> ``pruning_switch``    (``bool``-coerced, default ``False``)
    * ``scene``     -> ``pruning_scene``     (default ``"education"``)
    * ``threshold`` -> ``pruning_threshold`` (``float``-coerced, default ``0.5``)

    Returns:
        dict: A mapping with the three ``pruning_*`` keys above. Per the
        original author's note it is meant to be fed to
        ``PruningConfig.model_validate``.
    """
    section = RUNTIME_CONFIG.get("pruning", {}) or {}

    switch = bool(section.get("enabled", False))
    scene = section.get("scene", "education")
    threshold = float(section.get("threshold", 0.5))

    return {
        "pruning_switch": switch,
        "pruning_scene": scene,
        "pruning_threshold": threshold,
    }

View File

@@ -1,360 +1,268 @@
"""
配置加载模块 - 三阶段架构(已迁移到统一配置管理)
# """
# 配置加载模块 - DEPRECATED
本模块现在使用全局配置管理系统 (app/core/config.py)
来加载和管理配置,同时保持向后兼容性。
# ⚠️ DEPRECATION NOTICE ⚠️
# This module is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
阶段 1: 从 runtime.json 加载配置(路径 A
阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id
阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)
"""
import os
import json
import threading
from typing import Any, Dict, Optional
from datetime import datetime, timedelta
# Use the new MemoryConfig system instead:
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
# 阶段 1: 从 runtime.json 加载配置(路径 A- DEPRECATED
# 阶段 2: 从数据库加载配置(路径 B基于 dbrun.json 中的 config_id- DEPRECATED
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
# """
# import json
# import os
# import threading
# from datetime import datetime, timedelta
# from typing import Any, Dict, Optional
# Import unified configuration system
try:
from app.core.config import settings
USE_UNIFIED_CONFIG = True
except ImportError:
USE_UNIFIED_CONFIG = False
settings = None
# #TODO: Fix this
# PROJECT_ROOT 应该指向 app/core/memory/ 目录
# __file__ = app/core/memory/utils/config/definitions.py
# os.path.dirname(__file__) = app/core/memory/utils/config
# os.path.dirname(...) = app/core/memory/utils
# os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# try:
# from dotenv import load_dotenv
# load_dotenv()
# except Exception:
# pass
# 全局配置锁 - 用于线程安全
_config_lock = threading.RLock()
# # Import unified configuration system
# try:
# from app.core.config import settings
# USE_UNIFIED_CONFIG = True
# except ImportError:
# USE_UNIFIED_CONFIG = False
# settings = None
# 加载基础配置config.json- 使用全局配置系统
if USE_UNIFIED_CONFIG:
CONFIG = settings.load_memory_config()
else:
# Fallback to legacy loading
config_path = os.path.join(PROJECT_ROOT, "config.json")
try:
with open(config_path, "r") as f:
CONFIG = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
print("Warning: config.json not found or is malformed. Using default settings.")
CONFIG = {}
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
# # __file__ = app/core/memory/utils/config/definitions.py
# # os.path.dirname(__file__) = app/core/memory/utils/config
# # os.path.dirname(...) = app/core/memory/utils
# # os.path.dirname(...) = app/core/memory
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
DEFAULT_VALUES = {
"llm_name": "openai/qwen-plus",
"embedding_name": "openai/nomic-embed-text:v1.5",
"chunker_strategy": "RecursiveChunker",
"group_id": "group_123",
"user_id": "default_user",
"apply_id": "default_apply",
"llm_agent_name": "openai/qwen-plus",
"llm_verify_name": "openai/qwen-plus",
"llm_image_recognition": "openai/qwen-plus",
"llm_voice_recognition": "openai/qwen-plus",
"prompt_level": "DEBUG",
"reflexion_iteration_period": "3",
"reflexion_range": "retrieval",
"reflexion_baseline": "TIME",
}
# # DEPRECATED: Global configuration lock removed
# # Use MemoryConfig objects with dependency injection instead
# # DEPRECATED: Legacy config.json loading removed
# # Use MemoryConfig objects with dependency injection instead
# CONFIG = {}
# DEFAULT_VALUES = {
# "llm_name": "openai/qwen-plus",
# "embedding_name": "openai/nomic-embed-text:v1.5",
# "chunker_strategy": "RecursiveChunker",
# "group_id": "group_123",
# "user_id": "default_user",
# "apply_id": "default_apply",
# "llm_agent_name": "openai/qwen-plus",
# "llm_verify_name": "openai/qwen-plus",
# "llm_image_recognition": "openai/qwen-plus",
# "llm_voice_recognition": "openai/qwen-plus",
# "prompt_level": "DEBUG",
# "reflexion_iteration_period": "3",
# "reflexion_range": "retrieval",
# "reflexion_baseline": "TIME",
# }
# # DEPRECATED: Legacy global variables for backward compatibility only
# # These will be removed in a future version
# # Use MemoryConfig objects with dependency injection instead
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# 阶段 1: 从 runtime.json 加载配置(路径 A
def _load_from_runtime_json() -> Dict[str, Any]:
"""
从 runtime.json 文件加载配置(通过统一配置加载器)
# # 阶段 1: 从 runtime.json 加载配置(路径 A
# def _load_from_runtime_json() -> Dict[str, Any]:
# """
# DEPRECATED: Legacy runtime.json loading
使用 overrides.py 的统一配置加载器,按优先级加载:
1. 数据库配置(如果 dbrun.json 中有 config_id/group_id
2. 环境变量配置
3. runtime.json 默认配置
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
Returns:
Dict[str, Any]: 运行时配置字典
"""
try:
# 使用 overrides.py 的统一配置加载器
from app.core.memory.utils.config.overrides import load_unified_config
# Returns:
# Dict[str, Any]: Empty configuration (legacy support only)
# """
# import warnings
# warnings.warn(
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return {"selections": {}}
# # 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
# # 保留此函数仅为向后兼容
# def _load_from_database() -> Optional[Dict[str, Any]]:
# """
# DEPRECATED: Legacy database configuration loading
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# Returns:
# Optional[Dict[str, Any]]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
# """
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
# ⚠️ This function is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
# Use the new MemoryConfig system instead:
# - app.core.memory_config.config.MemoryConfig for configuration objects
# - Pass configuration objects as parameters instead of using global variables
# Args:
# runtime_cfg: 运行时配置字典
# """
# import warnings
# warnings.warn(
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Keep minimal global state for backward compatibility only
# # These will be removed in a future version
# global RUNTIME_CONFIG, SELECTIONS
# RUNTIME_CONFIG = runtime_cfg
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# # All other global variables have been removed
# # Use MemoryConfig objects instead
# # 初始化:使用统一配置加载器
# def _initialize_configuration() -> None:
# """
# DEPRECATED: Legacy configuration initialization
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# """
# import warnings
# warnings.warn(
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Initialize with empty configuration for backward compatibility
# _expose_runtime_constants({"selections": {}})
# # 模块加载时自动初始化配置
# _initialize_configuration()
# # DEPRECATED: Global variables removed
# # These variables have been eliminated in favor of dependency injection
# # Use MemoryConfig objects instead of accessing global variables
# # 公共 API动态重新加载配置
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
# """
# DEPRECATED: Legacy configuration reloading
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# Args:
# config_id: Configuration ID (deprecated)
# force_reload: Force reload flag (deprecated)
# Returns:
# bool: Always returns False (deprecated functionality)
# """
# import logging
# import warnings
# logger = logging.getLogger(__name__)
# warnings.warn(
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
# "Use MemoryConfig objects with dependency injection instead.")
# return False
# def get_current_config_id() -> Optional[str]:
# """
# DEPRECATED: Legacy config ID retrieval
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# Returns:
# Optional[str]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
# def ensure_fresh_config(config_id = None) -> bool:
# """
# DEPRECATED: Legacy configuration freshness check
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# Args:
# config_id: Configuration ID (deprecated)
runtime_cfg = load_unified_config(PROJECT_ROOT)
return runtime_cfg
except Exception as e:
# Fallback: 直接读取 runtime.json
runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
try:
with open(runtime_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e2:
pass # print(f"[definitions] ❌ 无法加载 runtime.json: {e2},使用空配置")
return {"selections": {}}
# 阶段 2: 从数据库加载配置(路径 B- 已整合到统一加载器
# 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
# 保留此函数仅为向后兼容
def _load_from_database() -> Optional[Dict[str, Any]]:
    """Load the configuration through the unified loader (legacy entry point).

    Kept only for backward compatibility: this function simply delegates to
    ``_load_from_runtime_json`` and adds a catch-all around it.

    Returns:
        The configuration dictionary, or ``None`` if loading raised.
    """
    try:
        return _load_from_runtime_json()
    except Exception:
        # Best-effort by design: a ``None`` result means "nothing loaded".
        return None
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
    """Publish a runtime configuration dictionary as module-level constants.

    Every value is read from ``runtime_cfg`` (or one of its sub-sections) and
    rebound as a module global; keys that are missing fall back to
    ``DEFAULT_VALUES``, ``False``, ``None`` or an empty dict as shown below.

    Args:
        runtime_cfg: The runtime configuration dictionary to expose.
    """
    global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
    global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
    global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
    global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
    global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
    global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
    global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE

    RUNTIME_CONFIG = runtime_cfg

    # Observability switches.
    langfuse_section = RUNTIME_CONFIG.get("langfuse", {})
    LANGFUSE_ENABLED = langfuse_section.get("enabled", False)
    agenta_section = RUNTIME_CONFIG.get("agenta", {})
    AGENTA_ENABLED = agenta_section.get("enabled", False)

    # Logging.
    LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
    PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])

    # Selections: everything below is looked up in this one section.
    SELECTIONS = RUNTIME_CONFIG.get("selections", {})
    pick = SELECTIONS.get

    # Base model choices.
    SELECTED_LLM_NAME = pick("llm_name", DEFAULT_VALUES["llm_name"])
    SELECTED_EMBEDDING_NAME = pick("embedding_name", DEFAULT_VALUES["embedding_name"])
    SELECTED_CHUNKER_STRATEGY = pick("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])

    # Group / user scoping.
    SELECTED_GROUP_ID = pick("group_id", DEFAULT_VALUES["group_id"])
    SELECTED_USER_ID = pick("user_id", DEFAULT_VALUES["user_id"])
    SELECTED_APPLY_ID = pick("apply_id", DEFAULT_VALUES["apply_id"])
    SELECTED_TEST_DATA_INDICES = pick("test_data_indices")

    # Purpose-specific LLMs.
    SELECTED_LLM_AGENT_NAME = pick("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
    SELECTED_LLM_VERIFY_NAME = pick("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
    SELECTED_LLM_PICTURE_NAME = pick("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
    SELECTED_LLM_VOICE_NAME = pick("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])

    # Model identifiers (no defaults: absent means None).
    SELECTED_LLM_ID = pick("llm_id")
    SELECTED_EMBEDDING_ID = pick("embedding_id")
    SELECTED_RERANK_ID = pick("rerank_id")

    # Reflexion settings.
    REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
    reflexion = REFLEXION_CONFIG.get
    REFLEXION_ENABLED = reflexion("enabled", False)
    REFLEXION_ITERATION_PERIOD = reflexion("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
    REFLEXION_RANGE = reflexion("reflexion_range", DEFAULT_VALUES["reflexion_range"])
    REFLEXION_BASELINE = reflexion("baseline", DEFAULT_VALUES["reflexion_baseline"])
# 初始化:使用统一配置加载器
def _initialize_configuration() -> None:
    """Initialise the module-level constants from the unified loader.

    Loads the configuration with ``_load_from_runtime_json`` and publishes it
    with ``_expose_runtime_constants``. Precedence between configuration
    sources is resolved inside the loader, not here.

    If loading or exposing fails, the failure is logged and an empty
    configuration (``{"selections": {}}``) is exposed instead, so importing
    the module never raises because of a bad configuration source.
    """
    try:
        runtime_config = _load_from_runtime_json()
        _expose_runtime_constants(runtime_config)
    except Exception as exc:
        # Previously swallowed silently; keep the fallback but leave a trace.
        import logging

        logging.getLogger(__name__).warning(
            "[definitions] configuration initialisation failed, "
            "falling back to an empty configuration: %s",
            exc,
        )
        _expose_runtime_constants({"selections": {}})
# 模块加载时自动初始化配置
_initialize_configuration()
# 公共 API动态重新加载配置
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
    """Reload the configuration for ``config_id`` and re-publish the constants.

    Meant for switching configuration at runtime. The new configuration is
    obtained from ``load_unified_config`` and exposed through
    ``_expose_runtime_constants``; this function itself writes no files.
    Each outcome is also reported to the audit logger when it can be imported.

    Args:
        config_id: Identifier of the configuration to load (int or str).
        force_reload: Unused; accepted only for backward compatibility.

    Returns:
        ``True`` when a configuration with non-empty ``selections`` was loaded
        and exposed, ``False`` on any failure.
    """
    import logging

    logger = logging.getLogger(__name__)

    # The audit logger is optional; the reload proceeds without it.
    try:
        from app.core.memory.utils.log.audit_logger import audit_logger
    except ImportError:
        audit_logger = None

    def _audit(**fields: Any) -> None:
        """Report the outcome to the audit log, if one is available."""
        if audit_logger:
            audit_logger.log_config_load(config_id=config_id, **fields)

    with _config_lock:
        try:
            from app.core.memory.utils.config.overrides import load_unified_config
        except Exception as e:
            logger.error(f"[definitions] 导入统一配置加载器失败: {e}")
            _audit(success=False, details={"error": f"Import failed: {str(e)}"})
            return False

        try:
            logger.info(f"[definitions] 开始重新加载配置config_id={config_id}")

            updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)

            # An empty result or empty selections is treated as "not found".
            if not updated_cfg or not updated_cfg.get('selections'):
                logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置")
                _audit(success=False, details={"reason": "config not found in database"})
                return False

            _expose_runtime_constants(updated_cfg)

            logger.info("[definitions] 配置重新加载成功,已暴露常量")
            logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
                         f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")

            chosen = updated_cfg.get('selections', {})
            _audit(
                user_id=chosen.get('user_id', None),
                group_id=chosen.get('group_id', None),
                success=True,
                details={
                    "llm_id": chosen.get('llm_id'),
                    "embedding_id": chosen.get('embedding_id'),
                    "chunker_strategy": chosen.get('chunker_strategy'),
                },
            )
            return True

        except Exception as e:
            logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True)
            _audit(success=False, details={"error": str(e)})
            return False
def get_current_config_id() -> Optional[str]:
    """Return the ``config_id`` stored in the active ``SELECTIONS``.

    Returns:
        The recorded ``config_id``, or ``None`` when none has been set.
    """
    return SELECTIONS.get("config_id")
def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
    """Make sure the module constants reflect the latest configuration.

    With a ``config_id`` the work is delegated to
    ``reload_configuration_from_database``; without one the configuration is
    reloaded through ``_load_from_database`` and re-exposed.

    Args:
        config_id: Optional configuration identifier (int or str).

    Returns:
        ``True`` if a configuration was loaded and exposed, ``False`` otherwise
        (including when any exception occurred).
    """
    import logging

    logger = logging.getLogger(__name__)

    try:
        if config_id:
            logger.debug(f"[definitions] 加载指定配置config_id={config_id}")
            # Delegate WITHOUT holding _config_lock: the callee acquires that
            # lock itself, and taking it here first would block forever if the
            # lock is not re-entrant.
            return reload_configuration_from_database(config_id)

        with _config_lock:
            logger.debug("[definitions] 从数据库重新加载最新配置")
            memory_config = _load_from_database()

            if not memory_config or not memory_config.get('selections'):
                logger.warning("[definitions] 未能从数据库加载配置,使用当前配置")
                return False

            # Fixed: the original called the undefined name
            # ``_expose_memory_constants``, so this path always failed.
            _expose_runtime_constants(memory_config)
            return True

    except Exception as e:
        logger.error(f"[definitions] 加载配置失败: {e}", exc_info=True)
        return False
# return False

View File

@@ -1,609 +0,0 @@
"""
运行时配置覆写工具 - 统一配置加载器
本模块作为统一的配置加载器,负责从多个来源加载配置并按优先级覆写。
配置来源优先级(从高到低):
1. 数据库配置PostgreSQL data_config 表)
2. 环境变量配置(.env 文件)
3. 默认配置runtime.json 文件)
支持的配置加载方式:
- 基于 config_id 的配置加载(从 dbrun.json 读取或前端传入)
- 基于 group_id 的配置加载(从 dbrun.json 读取)
- 环境变量覆写(支持 INTERNAL/EXTERNAL 网络模式)
主要功能:
- 从 PostgreSQL 数据库读取配置
- 从环境变量读取配置
- 从 runtime.json 读取默认配置
- 按优先级覆写配置项(仅在内存中,不修改文件)
- 支持多种配置字段selections、statement_extraction、deduplication、forgetting_engine、pruning、reflexion
使用场景:
- 应用启动时自动加载配置
- 前端切换配置时动态重新加载
- 多租户场景下的配置隔离
- 内外网环境自动切换
"""
import os
import json
from typing import Optional, Dict, Any, Literal
NetworkMode = Literal['internal', 'external']
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster) -> None:
    """Copy ``src[src_key]`` into ``target[target_key]`` after casting it.

    Nothing is written when the source value is absent or ``None``, or when
    the lookup or the cast raises: the helper is deliberately best-effort so
    one malformed field cannot abort a whole override pass.

    Args:
        target: Dictionary to write into.
        target_key: Key to set in ``target``.
        src: Dictionary to read from.
        src_key: Key to read in ``src``.
        caster: Callable applied to the source value before it is stored.
    """
    try:
        # ``.get`` yields None for a missing key, so one check covers both
        # "absent" and "explicitly None".
        value = src.get(src_key)
        if value is not None:
            target[target_key] = caster(value)
    except Exception:
        # Intentional: invalid values are skipped, leaving ``target`` untouched.
        pass
def _to_bool(val: Any) -> bool:
    """Coerce a loosely-typed value to ``bool``.

    Strings are trimmed and compared case-insensitively: "true", "1", "on",
    "yes" give ``True``; "false", "0", "off", "no" give ``False``. Any other
    string, and every non-string value, goes through ``bool(val)``.

    Args:
        val: The value to convert.

    Returns:
        The converted boolean; ``False`` if the conversion itself raises.
    """
    truthy_words = {"true", "1", "on", "yes"}
    falsy_words = {"false", "0", "off", "no"}
    try:
        if isinstance(val, str):
            word = val.strip().lower()
            if word in truthy_words:
                return True
            if word in falsy_words:
                return False
        # bool, int, float, unrecognised strings and everything else.
        return bool(val)
    except Exception:
        return False
def _make_pgsql_conn() -> Optional[object]:
    """Open a PostgreSQL connection configured from environment variables.

    Environment variables read:
        - DB_HOST: server host (defaults to "localhost")
        - DB_PORT: server port (defaults to 5432 when unset or empty)
        - DB_USER / DB_PASSWORD: credentials
        - DB_NAME: database name

    Returns:
        A psycopg2 connection with autocommit enabled, or ``None`` if the
        driver cannot be imported, the port is not an integer, or connecting
        fails.
    """
    read_env = os.getenv
    params = {
        "host": read_env("DB_HOST", "localhost"),
        "user": read_env("DB_USER"),
        "password": read_env("DB_PASSWORD"),
        "dbname": read_env("DB_NAME"),
    }
    raw_port = read_env("DB_PORT")
    try:
        import psycopg2  # type: ignore

        params["port"] = int(raw_port) if raw_port else 5432
        connection = psycopg2.connect(**params)
        connection.autocommit = True
        return connection
    except Exception:
        return None
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
    """Fetch the most recently updated ``data_config`` row for a group.

    Args:
        group_id: Group identifier to look up.

    Returns:
        The matching row as returned by a ``RealDictCursor``, or ``None`` when
        no connection could be made, no row matched, or the query failed.
    """
    conn = _make_pgsql_conn()
    if conn is None:
        return None

    # Bound up-front so the cleanup below never touches an undefined name
    # when creating the cursor itself fails.
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # Best-effort: the lookup is still attempted without it.
            pass

        sql = (
            "SELECT group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation "
            "FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
        )
        cur.execute(sql, (group_id,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        return None
    finally:
        # Close cursor then connection; a failure to close is not actionable.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
    """Fetch the ``data_config`` row for a configuration id.

    Args:
        config_id: Configuration identifier; converted with ``int()`` before
            it is used in the query.

    Returns:
        The matching row as returned by a ``RealDictCursor``, or ``None`` when
        ``config_id`` is not convertible to an integer, no connection could be
        made, no row matched, or the query failed.
    """
    # The column is an integer (per the original note), so reject anything
    # non-numeric before spending a database connection on it.
    try:
        config_id_int = int(config_id)
    except (ValueError, TypeError, OverflowError):
        return None

    conn = _make_pgsql_conn()
    if conn is None:
        return None

    # Bound up-front so the cleanup below never touches an undefined name
    # when creating the cursor itself fails.
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore

        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            # Best-effort: the lookup is still attempted without it.
            pass

        sql = (
            "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation, "
            " deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
            " statement_granularity, include_dialogue_context, max_context, "
            " \"offset\" AS offset, lambda_time, lambda_mem, "
            " pruning_enabled, pruning_scene, pruning_threshold, "
            " llm_id, embedding_id, rerank_id "
            "FROM data_config WHERE config_id = %s LIMIT 1"
        )
        cur.execute(sql, (config_id_int,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        return None
    finally:
        # Close cursor then connection; a failure to close is not actionable.
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
    """Read ``group_id`` from ``<project_root>/dbrun.json``.

    The top-level ``group_id`` is preferred; otherwise ``selections.group_id``
    is used. A JSON ``null`` counts as "not set" rather than being turned into
    the string ``"None"``.

    Args:
        project_root: Directory expected to contain ``dbrun.json``.

    Returns:
        The group id as a string, or ``None`` if the file is missing,
        unreadable, not a JSON object, or carries no group id.
    """
    try:
        path = os.path.join(project_root, "dbrun.json")
        if not os.path.isfile(path):
            return None
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        return None

    if not isinstance(data, dict):
        return None

    top_level = data.get("group_id")
    if top_level is not None:
        return str(top_level)

    sel = data.get("selections", {})
    if isinstance(sel, dict) and sel.get("group_id") is not None:
        return str(sel["group_id"])
    return None
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
    """Read ``config_id`` from ``<project_root>/dbrun.json``.

    The top-level ``config_id`` is preferred; otherwise
    ``selections.config_id`` is used. A JSON ``null`` counts as "not set"
    rather than being turned into the string ``"None"``.

    Args:
        project_root: Directory expected to contain ``dbrun.json``.

    Returns:
        The config id as a string, or ``None`` if the file is missing,
        unreadable, not a JSON object, or carries no config id.
    """
    try:
        path = os.path.join(project_root, "dbrun.json")
        if not os.path.isfile(path):
            return None
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        return None

    if not isinstance(data, dict):
        return None

    top_level = data.get("config_id")
    if top_level is not None:
        return str(top_level)

    sel = data.get("selections", {})
    if isinstance(sel, dict) and sel.get("config_id") is not None:
        return str(sel["config_id"])
    return None
def _apply_overrides_from_db_row(
    runtime_cfg: Dict[str, Any],
    db_row: Optional[Dict[str, Any]],
    identifier: str,
    identifier_type: str = "config_id"
) -> Dict[str, Any]:
    """Overlay values from a ``data_config`` row onto the runtime configuration.

    ``runtime_cfg`` is modified in place and also returned. The identifier is
    always recorded under ``selections[identifier_type]``; the remaining
    sections are only touched when ``db_row`` is non-empty. Individual fields
    that are missing, ``None`` or fail to convert are skipped.

    Args:
        runtime_cfg: Runtime configuration dictionary to update.
        db_row: Row fetched from the database, or ``None``.
        identifier: The group id or config id that was used for the lookup.
        identifier_type: Either "group_id" or "config_id".

    Returns:
        The (possibly updated) ``runtime_cfg``; on an unexpected error the
        dictionary is returned in whatever state it had reached.
    """
    try:
        selections = runtime_cfg.setdefault("selections", {})
        selections[identifier_type] = identifier

        if not db_row:
            return runtime_cfg

        # selections: stored as strings. The model ids may arrive as UUID
        # objects; ``str`` gives their standard hyphenated form.
        for key in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
                    "t_type_strict", "t_name_strict", "t_overall",
                    "statement_granularity", "include_dialogue_context",
                    "llm_id", "embedding_id", "rerank_id"):
            _set_if_present(selections, key, db_row, key, str)

        # statement_extraction
        stmt = runtime_cfg.setdefault("statement_extraction", {})
        _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
        _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
        _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)

        # deduplication
        dedup = runtime_cfg.setdefault("deduplication", {})
        for key in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation", "deep_retrieval"):
            _set_if_present(dedup, key, db_row, key, _to_bool)

        # forgetting_engine
        forgetting = runtime_cfg.setdefault("forgetting_engine", {})
        for key in ("offset", "lambda_time", "lambda_mem"):
            _set_if_present(forgetting, key, db_row, key, float)

        # pruning: the threshold is clamped to [0.0, 0.9].
        pruning = runtime_cfg.setdefault("pruning", {})
        _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
        _set_if_present(pruning, "scene", db_row, "pruning_scene", str)
        _set_if_present(
            pruning, "threshold", db_row, "pruning_threshold",
            lambda raw: max(0.0, min(0.9, float(raw))),
        )

        return runtime_cfg
    except Exception:
        # Best-effort overlay: never let a malformed row break config loading.
        return runtime_cfg
def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the ``group_id`` in dbrun.json.

    Steps:
        1. Read ``group_id`` from dbrun.json.
        2. Fetch the matching database row.
        3. Overlay it onto ``runtime_cfg`` (in memory only).

    When no row is found the ``group_id`` is still recorded under
    ``selections``.

    Args:
        project_root: Project root directory.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        The runtime configuration after any overrides; unchanged when no
        ``group_id`` is configured or an error occurs.
    """
    try:
        group = _load_dbrun_group_id(project_root)
        if group:
            row = _fetch_db_config_by_group_id(group)
            if row:
                return _apply_overrides_from_db_row(runtime_cfg, row, group, "group_id")
            # No row: remember which group was requested anyway.
            runtime_cfg.setdefault("selections", {})["group_id"] = group
        return runtime_cfg
    except Exception:
        return runtime_cfg
def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay database settings selected by the ``config_id`` in dbrun.json.

    Steps:
        1. Read ``config_id`` from dbrun.json.
        2. Fetch the matching database row.
        3. Overlay it onto ``runtime_cfg`` (in memory only).

    Args:
        project_root: Project root directory.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        The runtime configuration after any overrides; unchanged when no
        ``config_id`` is configured or an error occurs.
    """
    try:
        cid = _load_dbrun_config_id(project_root)
        if cid:
            # A missing row is passed through as well; the helper then only
            # records the identifier.
            row = _fetch_db_config_by_config_id(cid)
            return _apply_overrides_from_db_row(runtime_cfg, row, cid, "config_id")
    except Exception:
        pass
    return runtime_cfg
def apply_runtime_overrides_with_config_id(
    project_root: str,
    runtime_cfg: Dict[str, Any],
    config_id: str
) -> tuple[Dict[str, Any], bool]:
    """Overlay database settings for an explicitly supplied ``config_id``.

    Unlike the dbrun.json-based variants, the identifier comes from the
    caller, which suits switching configuration dynamically.

    Args:
        project_root: Project root directory (not used by this function).
        runtime_cfg: Runtime configuration dictionary.
        config_id: Configuration identifier.

    Returns:
        ``(runtime configuration, loaded)`` where ``loaded`` is ``True`` only
        when a database row was found and applied.
    """
    try:
        cid = str(config_id).strip()
        row = _fetch_db_config_by_config_id(cid) if cid else None
        if row is not None:
            return _apply_overrides_from_db_row(runtime_cfg, row, cid, "config_id"), True
    except Exception:
        pass
    return runtime_cfg, False
# ============================================================================
# 以下函数已注释:不再需要网络模式自动检测功能
# ============================================================================
# def get_server_ip() -> str:
# """
# 获取当前服务器的IP地址
#
# Returns:
# 服务器IP地址字符串
# """
# try:
# # 方式1从环境变量获取优先
# server_ip = os.getenv('SERVER_IP')
# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
# return server_ip
#
# # 方式2通过socket获取
# hostname = socket.gethostname()
# ip_address = socket.gethostbyname(hostname)
#
# # 如果是本地回环地址尝试获取真实IP
# if ip_address.startswith('127.'):
# # 尝试连接外部地址来获取本机IP
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# try:
# s.connect(('8.8.8.8', 80))
# ip_address = s.getsockname()[0]
# finally:
# s.close()
#
# return ip_address
# except Exception as e:
# print(f"[overrides] 获取服务器IP失败: {e},使用默认值 127.0.0.1")
# return '127.0.0.1'
# def auto_detect_network_mode() -> NetworkMode:
# """
# 自动检测网络模式基于服务器IP
#
# 规则:
# - 如果服务器IP在内网IP列表中 → internal内网
# - 其他IP → external外网
#
# 可以通过环境变量 INTERNAL_SERVER_IPS 自定义内网IP列表逗号分隔
#
# Returns:
# 'internal' 或 'external'
# """
# server_ip = get_server_ip()
#
# # 从环境变量获取内网IP列表支持多个IP逗号分隔
# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
#
# # 判断当前IP是否在内网IP列表中
# if server_ip in internal_ips:
# print(f"[overrides] 自动检测服务器IP {server_ip} 属于内网,使用 INTERNAL 配置")
# return 'internal'
# else:
# print(f"[overrides] 自动检测服务器IP {server_ip} 属于外网,使用 EXTERNAL 配置")
# return 'external'
# ============================================================================
# 环境变量覆写功能已废弃 - 不再使用
# ============================================================================
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
# """
# 从环境变量覆写配置(已废弃)
# """
# return runtime_cfg
def load_unified_config(
    project_root: str,
    config_id: Optional[int | str] = None,
    group_id: Optional[str] = None,
    network_mode: Optional[NetworkMode] = None,
    env_override_models: bool = True
) -> Dict[str, Any]:
    """Build the runtime configuration from runtime.json plus database overrides.

    ``runtime.json`` supplies the base configuration (an empty
    ``{"selections": {}}`` if it is missing or invalid). One database row is
    then overlaid on top, chosen by the first identifier available in this
    order:

        1. the ``config_id`` argument
        2. the ``group_id`` argument
        3. ``config_id`` from dbrun.json
        4. ``group_id`` from dbrun.json

    Only the first available identifier is tried; if its lookup returns no
    row, the base configuration is returned without trying the next one.

    Args:
        project_root: Project root directory.
        config_id: Configuration id (int or str), optional.
        group_id: Group id, optional.
        network_mode: Deprecated and ignored; kept for backward compatibility.
        env_override_models: Deprecated and ignored; kept for backward
            compatibility.

    Returns:
        The final runtime configuration, or ``{"selections": {}}`` if an
        unexpected error occurs.
    """
    try:
        # Step 1: base configuration from runtime.json.
        runtime_config_path = os.path.join(project_root, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                runtime_cfg = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            runtime_cfg = {"selections": {}}

        # Step 2: decide which identifier drives the database lookup.
        if config_id:
            identifier, identifier_type = config_id, "config_id"
        elif group_id:
            identifier, identifier_type = group_id, "group_id"
        else:
            identifier, identifier_type = _load_dbrun_config_id(project_root), "config_id"
            if not identifier:
                identifier, identifier_type = _load_dbrun_group_id(project_root), "group_id"

        if not identifier:
            return runtime_cfg

        # Step 3: fetch the row and overlay it (highest priority).
        if identifier_type == "config_id":
            db_row = _fetch_db_config_by_config_id(identifier)
        else:
            db_row = _fetch_db_config_by_group_id(identifier)
        if db_row:
            runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, identifier, identifier_type)

        return runtime_cfg
    except Exception:
        return {"selections": {}}
# 向后兼容的别名
apply_runtime_overrides = apply_runtime_overrides_by_config

View File

@@ -0,0 +1,11 @@
"""Embedder utilities module."""
from app.core.memory.utils.embedder.embedder_utils import (
get_embedder_client,
get_embedder_client_from_config,
)
__all__ = [
"get_embedder_client",
"get_embedder_client_from_config",
]

View File

@@ -0,0 +1,83 @@
"""Embedder Client Utilities
This module provides centralized functions for creating embedder clients.
"""
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
def get_embedder_client_from_config(memory_config: "MemoryConfig") -> OpenAIEmbedderClient:
    """Create an embedder client for the model named in a ``MemoryConfig``.

    This is the preferred entry point whenever a ``MemoryConfig`` object is
    at hand; it forwards the configured model id to ``get_embedder_client``.

    Args:
        memory_config: Configuration object providing ``embedding_model_id``.

    Returns:
        OpenAIEmbedderClient: Initialised embedder client.

    Raises:
        ValueError: If no embedding model id is configured, or if
            ``get_embedder_client`` fails to build the client.

    Example:
        >>> embedder_client = get_embedder_client_from_config(memory_config)
    """
    model_id = memory_config.embedding_model_id
    if model_id:
        return get_embedder_client(str(model_id))
    raise ValueError(
        f"Configuration {memory_config.config_id} has no embedding model configured"
    )
def get_embedder_client(embedding_id: str) -> OpenAIEmbedderClient:
    """Create an embedder client from an embedding model id.

    Intended for tests, evaluations and legacy call sites that only have a
    model id. Code that already holds a ``MemoryConfig`` should call
    ``get_embedder_client_from_config`` instead.

    Args:
        embedding_id: Embedding model id (required).

    Returns:
        OpenAIEmbedderClient: Initialised embedder client.

    Raises:
        ValueError: If ``embedding_id`` is empty, the model configuration
            cannot be fetched, or the client cannot be initialised.

    Example:
        >>> # For tests/evaluations only
        >>> embedder_client = get_embedder_client("model-uuid-string")
    """
    if not embedding_id:
        raise ValueError("Embedding ID is required but was not provided")

    try:
        with get_db_context() as db:
            embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_id)
    except Exception as e:
        raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e

    try:
        embedder_config = RedBearModelConfig(**embedder_config_dict)
        return OpenAIEmbedderClient(embedder_config)
    except Exception as e:
        # The config may not be a mapping (e.g. None) - that is one way to
        # land here - so look the name up defensively instead of raising a
        # second, unrelated error from inside the handler.
        lookup = getattr(embedder_config_dict, "get", None)
        model_name = lookup('model_name', 'unknown') if callable(lookup) else 'unknown'
        raise ValueError(
            f"Failed to initialize embedder client for model '{model_name}': {str(e)}"
        ) from e

View File

@@ -1,77 +1,237 @@
import os
from pydantic import BaseModel
from typing import TYPE_CHECKING
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
from pydantic import BaseModel
from sqlalchemy.orm import Session
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
async def handle_response(response: BaseModel) -> dict:
    """Serialise a pydantic model instance to a plain dictionary.

    Args:
        response: A model *instance*. The previous ``type[BaseModel]``
            annotation described the class, but ``model_dump`` is called on
            the value itself.

    Returns:
        The result of ``response.model_dump()``.
    """
    return response.model_dump()
def get_llm_client(llm_id: str | None = None):
    """Create an LLM client for ``llm_id``.

    Args:
        llm_id: Model id; when falsy, ``config_defs.SELECTED_LLM_ID`` is used.

    Returns:
        An ``OpenAIClient`` built from the model's stored configuration.

    Raises:
        ValueError: If no id is available, the model configuration cannot be
            fetched, or the client cannot be initialised.
    """
    llm_id = llm_id or config_defs.SELECTED_LLM_ID

    # Fail early when neither the argument nor the global selection gives an id.
    if not llm_id:
        raise ValueError("LLM ID is required but was not provided")

    try:
        model_config = get_model_config(llm_id)
    except Exception as e:
        raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e

    try:
        fields = {
            key: model_config.get(key)
            for key in ("model_name", "provider", "api_key", "base_url")
        }
        return OpenAIClient(RedBearModelConfig(**fields), type_=model_config.get("type"))
    except Exception as e:
        model_name = model_config.get('model_name', 'unknown')
        raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
def get_reranker_client(rerank_id: str | None = None):
class MemoryClientFactory:
    """Factory for creating LLM, embedder, and reranker clients.

    Construct it once with a database session, then call its methods without
    passing the session again.

    Example:
        >>> factory = MemoryClientFactory(db)
        >>> llm_client = factory.get_llm_client(model_id)
        >>> embedder_client = factory.get_embedder_client(embedding_id)
    """

    # Keys copied from a stored model configuration into RedBearModelConfig.
    _MODEL_FIELDS = ("model_name", "provider", "api_key", "base_url")

    def __init__(self, db: Session):
        """Bind the factory to ``db`` by creating its ``MemoryConfigService``."""
        from app.services.memory_config_service import MemoryConfigService
        self._config_service = MemoryConfigService(db)

    @classmethod
    def _to_model_config(cls, raw) -> RedBearModelConfig:
        """Build a ``RedBearModelConfig`` from the known keys of ``raw``."""
        return RedBearModelConfig(**{key: raw.get(key) for key in cls._MODEL_FIELDS})

    @staticmethod
    def _model_name(raw) -> str:
        """Return ``raw['model_name']`` for error messages, tolerating non-mappings."""
        lookup = getattr(raw, "get", None)
        return lookup('model_name', 'unknown') if callable(lookup) else 'unknown'

    def _get_chat_client(self, model_id: str, id_kind: str, client_kind: str) -> OpenAIClient:
        """Shared implementation behind ``get_llm_client`` / ``get_reranker_client``.

        ``id_kind`` and ``client_kind`` only vary the wording of error messages.
        """
        if not model_id:
            # "LLM" -> "LLM ID is required", "rerank" -> "Rerank ID is required".
            raise ValueError(f"{id_kind[:1].upper()}{id_kind[1:]} ID is required")

        try:
            model_config = self._config_service.get_model_config(model_id)
        except Exception as e:
            raise ValueError(f"Invalid {id_kind} ID '{model_id}': {str(e)}") from e

        try:
            return OpenAIClient(
                self._to_model_config(model_config),
                type_=model_config.get("type")
            )
        except Exception as e:
            model_name = self._model_name(model_config)
            raise ValueError(
                f"Failed to initialize {client_kind} client for model '{model_name}': {str(e)}"
            ) from e

    def get_llm_client(self, llm_id: str) -> OpenAIClient:
        """Get LLM client by model ID.

        Raises:
            ValueError: If ``llm_id`` is empty, unknown, or the client cannot
                be initialised.
        """
        return self._get_chat_client(llm_id, "LLM", "LLM")

    def get_embedder_client(self, embedding_id: str):
        """Get embedder client by model ID.

        Raises:
            ValueError: If ``embedding_id`` is empty, unknown, or the client
                cannot be initialised.
        """
        from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient

        if not embedding_id:
            raise ValueError("Embedding ID is required")

        try:
            embedder_config = self._config_service.get_embedder_config(embedding_id)
        except Exception as e:
            raise ValueError(f"Invalid embedding ID '{embedding_id}': {str(e)}") from e

        try:
            return OpenAIEmbedderClient(self._to_model_config(embedder_config))
        except Exception as e:
            model_name = self._model_name(embedder_config)
            raise ValueError(
                f"Failed to initialize embedder client for model '{model_name}': {str(e)}"
            ) from e

    def get_reranker_client(self, rerank_id: str) -> OpenAIClient:
        """Get reranker client by model ID.

        Raises:
            ValueError: If ``rerank_id`` is empty, unknown, or the client
                cannot be initialised.
        """
        return self._get_chat_client(rerank_id, "rerank", "reranker")

    def get_llm_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
        """Get LLM client from MemoryConfig object.

        Args:
            memory_config: Configuration containing llm_model_id

        Returns:
            OpenAIClient configured for the LLM model

        Raises:
            ValueError: If memory_config has no LLM model configured
        """
        if not memory_config.llm_model_id:
            raise ValueError(
                f"Configuration {memory_config.config_id} has no LLM model configured"
            )
        return self.get_llm_client(str(memory_config.llm_model_id))

    def get_embedder_client_from_config(self, memory_config: "MemoryConfig"):
        """Get embedder client from MemoryConfig object.

        Args:
            memory_config: Configuration containing embedding_model_id

        Returns:
            OpenAIEmbedderClient configured for the embedding model

        Raises:
            ValueError: If memory_config has no embedding model configured
        """
        if not memory_config.embedding_model_id:
            raise ValueError(
                f"Configuration {memory_config.config_id} has no embedding model configured"
            )
        return self.get_embedder_client(str(memory_config.embedding_model_id))

    def get_reranker_client_from_config(self, memory_config: "MemoryConfig") -> OpenAIClient:
        """Get reranker client from MemoryConfig object.

        Args:
            memory_config: Configuration containing rerank_model_id

        Returns:
            OpenAIClient configured for the reranker model

        Raises:
            ValueError: If memory_config has no rerank model configured
        """
        if not memory_config.rerank_model_id:
            raise ValueError(
                f"Configuration {memory_config.config_id} has no rerank model configured"
            )
        return self.get_reranker_client(str(memory_config.rerank_model_id))
# Legacy functions for backward compatibility
def get_llm_client_from_config(memory_config: "MemoryConfig", db: Session) -> OpenAIClient:
    """Get LLM client from MemoryConfig object.

    DEPRECATED: Use MemoryClientFactory(db).get_llm_client_from_config(memory_config) instead.

    This function is maintained for backward compatibility during migration to the
    factory pattern. New code should create a MemoryClientFactory instance and use
    its get_llm_client_from_config method directly.

    Args:
        memory_config: Configuration containing llm_model_id
        db: Database session

    Returns:
        OpenAIClient configured for the LLM model

    Raises:
        ValueError: If memory_config has no LLM model configured
    """
    # A leftover ``rerank_id = rerank_id or ...`` statement from the removed
    # reranker helper used to sit here; it read an unbound local and has no
    # place in this wrapper, so it is gone.
    return MemoryClientFactory(db).get_llm_client_from_config(memory_config)
def get_llm_client(llm_id: str, db: Session) -> OpenAIClient:
    """Create an LLM client by model id (legacy wrapper).

    DEPRECATED: Use MemoryClientFactory(db).get_llm_client(llm_id) instead.

    Retained for backward compatibility while call sites migrate to the
    factory; new code should build a ``MemoryClientFactory`` and call its
    ``get_llm_client`` method directly.

    Args:
        llm_id: LLM model ID
        db: Database session

    Returns:
        OpenAIClient configured for the LLM model
    """
    factory = MemoryClientFactory(db)
    return factory.get_llm_client(llm_id)
def get_embedder_client(embedding_id: str, db: Session):
    """Return an embedder client for the given embedding model (deprecated shim).

    DEPRECATED: Use MemoryClientFactory(db).get_embedder_client(embedding_id) instead.

    Kept only so that callers written before the factory pattern keep working;
    new code should build a ``MemoryClientFactory`` and call its
    ``get_embedder_client`` method itself.

    Args:
        embedding_id: Embedding model ID, forwarded unchanged to the factory.
        db: Database session used to construct the factory.

    Returns:
        The value returned by ``MemoryClientFactory.get_embedder_client``
        (described by the original docstring as an OpenAIEmbedderClient).
    """
    factory = MemoryClientFactory(db)
    return factory.get_embedder_client(embedding_id)
def get_reranker_client(rerank_id: str, db: Session) -> OpenAIClient:
    """Return a reranker client for the given reranker model (deprecated shim).

    DEPRECATED: Use MemoryClientFactory(db).get_reranker_client(rerank_id) instead.

    Kept only so that callers written before the factory pattern keep working;
    new code should build a ``MemoryClientFactory`` and call its
    ``get_reranker_client`` method itself.

    Args:
        rerank_id: Reranker model ID, forwarded unchanged to the factory.
        db: Database session used to construct the factory.

    Returns:
        OpenAIClient configured for the reranker model.
    """
    factory = MemoryClientFactory(db)
    return factory.get_reranker_client(rerank_id)

View File

@@ -6,11 +6,12 @@
"""
import logging
from typing import List, Any
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db_context
from app.schemas.memory_storage_schema import ConflictResultSchema
from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def conflict(evaluate_data: List[Any]) -> List[Any]:
冲突记忆列表JSON 数组)。
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -6,11 +6,12 @@
"""
import logging
from typing import List, Any
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.db import get_db_context
from app.schemas.memory_storage_schema import ReflexionResultSchema
from pydantic import BaseModel
@@ -25,7 +26,9 @@ async def reflexion(ref_data: List[Any]) -> List[Any]:
反思结果列表JSON 数组)。
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")

View File

@@ -10,28 +10,29 @@
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
"""
import os
import asyncio
import json
import logging
import asyncio
from typing import List, Dict, Any
import os
import uuid
from typing import Any, Dict, List
#TODO: Fix this
# Default values (previously from definitions.py)
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
from app.core.memory.utils.config.definitions import (
REFLEXION_ENABLED,
REFLEXION_ITERATION_PERIOD,
REFLEXION_RANGE,
REFLEXION_BASELINE,
)
from app.db import get_db
from sqlalchemy.orm import Session
from app.models.retrieval_info import RetrievalInfo
from app.core.memory.utils.config.get_data import get_data
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from sqlalchemy.orm import Session
# 并发限制(可通过环境变量覆盖)
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))

View File

@@ -5,16 +5,24 @@ This module provides functionality to analyze chunk content and generate insight
"""
import asyncio
from typing import List, Dict, Any
from collections import Counter
from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
    """Open a short-lived DB context and return an LLM client built from it.

    ``None`` is passed as the model ID; per the original inline comment this
    is meant to select the default LLM (NOTE(review): confirm that
    ``MemoryClientFactory.get_llm_client`` accepts ``None``).
    """
    with get_db_context() as session:
        return MemoryClientFactory(session).get_llm_client(None)  # Uses default LLM
class ChunkInsight(BaseModel):
"""Pydantic model for chunk insight."""
insight: str = Field(..., description="对chunk内容的深度洞察分析")
@@ -40,7 +48,7 @@ async def classify_chunk_domain(chunk: str) -> str:
Domain name
"""
try:
llm_client = get_llm_client()
llm_client = _get_llm_client()
prompt = f"""请将以下文本内容归类到最合适的领域中。
@@ -177,7 +185,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
]
# 调用LLM生成洞察
llm_client = get_llm_client()
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
insight = response.content.strip()

View File

@@ -5,15 +5,23 @@ This module provides functionality to summarize chunk content using LLM.
"""
import asyncio
from typing import List, Dict, Any
from typing import Any, Dict, List
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
    """Open a short-lived DB context and return an LLM client built from it.

    ``None`` is passed as the model ID; per the original inline comment this
    is meant to select the default LLM (NOTE(review): confirm that
    ``MemoryClientFactory.get_llm_client`` accepts ``None``).
    """
    with get_db_context() as session:
        return MemoryClientFactory(session).get_llm_client(None)  # Uses default LLM
class ChunkSummary(BaseModel):
"""Pydantic model for chunk summary."""
summary: str = Field(..., description="简洁的chunk内容摘要")
@@ -59,7 +67,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
]
# 调用LLM生成摘要
llm_client = get_llm_client()
llm_client = _get_llm_client()
response = await llm_client.chat(messages=messages)
summary = response.content.strip()

View File

@@ -7,14 +7,22 @@ This module provides functionality to extract meaningful tags from chunk content
import asyncio
from collections import Counter
from typing import List, Tuple
from app.core.logging_config import get_business_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from pydantic import BaseModel, Field
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def _get_llm_client():
    """Open a short-lived DB context and return an LLM client built from it.

    ``None`` is passed as the model ID; per the original inline comment this
    is meant to select the default LLM (NOTE(review): confirm that
    ``MemoryClientFactory.get_llm_client`` accepts ``None``).
    """
    with get_db_context() as session:
        return MemoryClientFactory(session).get_llm_client(None)  # Uses default LLM
class ExtractedTags(BaseModel):
"""Pydantic model for extracted tags."""
tags: List[str] = Field(..., description="从文本中提取的关键标签列表")
@@ -56,7 +64,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
)
llm_client = get_llm_client()
llm_client = _get_llm_client()
# 为每个chunk单独提取标签然后统计频率
all_tags = []
@@ -151,7 +159,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
]
# 调用LLM提取人物形象
llm_client = get_llm_client()
llm_client = _get_llm_client()
structured_response = await llm_client.response_structured(
messages=messages,
response_model=ExtractedPersona

View File

@@ -1,6 +1,21 @@
"""
Validators for file upload system.
Validators package for various validation utilities.
"""
from app.core.validators.file_validator import FileValidator, ValidationResult
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_llm_model,
validate_model_exists_and_active,
)
__all__ = ["FileValidator", "ValidationResult"]
__all__ = [
# File validators
"FileValidator",
"ValidationResult",
# Memory config validators
"validate_model_exists_and_active",
"validate_and_resolve_model_id",
"validate_embedding_model",
"validate_llm_model",
]

View File

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Validators
This module provides validation functions for memory configuration models.
Functions:
validate_model_exists_and_active: Validate model exists and is active
validate_and_resolve_model_id: Validate and resolve model ID with DB lookup
validate_embedding_model: Validate embedding model availability
validate_llm_model: Validate LLM model availability
"""
import time
from typing import Optional, Union
from uuid import UUID
from app.core.logging_config import get_config_logger
from app.schemas.memory_config_schema import (
InvalidConfigError,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID."""
if model_id is None:
return None
if isinstance(model_id, UUID):
return model_id
if isinstance(model_id, str):
if not model_id.strip():
return None
try:
return UUID(model_id.strip())
except ValueError:
raise InvalidConfigError(
f"Invalid UUID format for {model_type} model ID: '{model_id}'",
field_name=f"{model_type}_model_id",
invalid_value=model_id,
config_id=config_id,
workspace_id=workspace_id
)
raise InvalidConfigError(
f"Invalid type for {model_type} model ID: expected str or UUID, got {type(model_id).__name__}",
field_name=f"{model_type}_model_id",
invalid_value=model_id,
config_id=config_id,
workspace_id=workspace_id
)
def validate_model_exists_and_active(
    model_id: UUID,
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[str, bool]:
    """Validate that a model exists and is active.

    Args:
        model_id: Model UUID to validate
        model_type: Type of model ("llm", "embedding", "rerank")
        db: Database session
        tenant_id: Optional tenant ID, passed through to the repository lookup
        config_id: Optional configuration ID for error context
        workspace_id: Optional workspace ID for error context

    Returns:
        Tuple of (model_name, is_active)

    Raises:
        ModelNotFoundError: If the repository lookup returns a falsy result
        ModelInactiveError: If the model exists but ``is_active`` is falsy
        Exception: Any other error from the lookup is logged and re-raised
            unchanged.
    """
    # Function-scope import kept from the original (presumably to avoid an
    # import cycle -- TODO confirm).
    from app.repositories.model_repository import ModelConfigRepository

    # perf_counter() is monotonic; time.time() can jump when the system clock
    # is adjusted and produce negative or inflated elapsed_ms values.
    started = time.perf_counter()

    try:
        model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
        elapsed_ms = (time.perf_counter() - started) * 1000

        if not model:
            logger.warning(
                "Model not found",
                extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
            )
            raise ModelNotFoundError(
                model_id=model_id,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} not found"
            )

        if not model.is_active:
            logger.warning(
                "Model inactive",
                extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
            )
            raise ModelInactiveError(
                model_id=model_id,
                model_name=model.name,
                model_type=model_type,
                config_id=config_id,
                workspace_id=workspace_id,
                message=f"{model_type.title()} model {model_id} ({model.name}) is inactive"
            )

        logger.debug(
            "Model validation successful",
            extra={"model_id": str(model_id), "model_name": model.name, "elapsed_ms": elapsed_ms}
        )
        return model.name, model.is_active

    except (ModelNotFoundError, ModelInactiveError):
        # Already logged above; let the domain errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Model validation failed: {e}", exc_info=True)
        raise
def validate_and_resolve_model_id(
    model_id_str: Union[str, UUID, None],
    model_type: str,
    db: Session,
    tenant_id: Optional[UUID] = None,
    required: bool = False,
    config_id: Optional[int] = None,
    workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]:
    """Parse a model ID and confirm the model exists and is active.

    Args:
        model_id_str: ``None``, a ``UUID``, or a string holding a UUID.
        model_type: Model kind, used in messages and the reported field name.
        db: Database session handed to the existence/active check.
        tenant_id: Optional tenant ID handed to the existence/active check.
        required: When true, a missing or blank ID raises instead of
            returning ``(None, None)``.
        config_id: Optional configuration ID for error context.
        workspace_id: Optional workspace ID for error context.

    Returns:
        Tuple of (validated_uuid, model_name) or (None, None) if not required and empty

    Raises:
        InvalidConfigError: If the ID is required but missing, or malformed.
        ModelNotFoundError / ModelInactiveError: Propagated from
            ``validate_model_exists_and_active``.
    """
    is_blank = model_id_str is None or (
        isinstance(model_id_str, str) and not model_id_str.strip()
    )
    # Blank input never reaches the parser; anything else is parsed (the
    # parser may itself return None or raise InvalidConfigError).
    model_uuid = None if is_blank else _parse_model_id(
        model_id_str, model_type, config_id, workspace_id
    )

    if model_uuid is not None:
        model_name, _ = validate_model_exists_and_active(
            model_uuid, model_type, db, tenant_id, config_id, workspace_id
        )
        return model_uuid, model_name

    if not required:
        return None, None

    raise InvalidConfigError(
        f"{model_type.title()} model ID is required",
        field_name=f"{model_type}_model_id",
        invalid_value=model_id_str,
        config_id=config_id,
        workspace_id=workspace_id
    )
def validate_embedding_model(
    config_id: int,
    embedding_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Check that the configured embedding model is usable and return its UUID.

    Args:
        config_id: Configuration ID, used in messages and error context.
        embedding_id: Embedding model ID as ``str``, ``UUID`` or ``None``.
        db: Database session used for the model lookup.
        tenant_id: Optional tenant ID forwarded to the lookup.
        workspace_id: Optional workspace ID for error context.

    Raises:
        InvalidConfigError: If embedding_id is not provided or invalid
        ModelNotFoundError: If embedding model does not exist
        ModelInactiveError: If embedding model is inactive
    """
    def _not_configured() -> InvalidConfigError:
        # Single place that builds the "nothing configured" error.
        return InvalidConfigError(
            f"Configuration {config_id} has no embedding model configured",
            field_name="embedding_model_id",
            invalid_value=embedding_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    missing = embedding_id is None or (
        isinstance(embedding_id, str) and not embedding_id.strip()
    )
    if missing:
        raise _not_configured()

    resolved, _name = validate_and_resolve_model_id(
        embedding_id, "embedding", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )
    if resolved is None:
        raise _not_configured()
    return resolved
def validate_llm_model(
    config_id: int,
    llm_id: Union[str, UUID, None],
    db: Session,
    tenant_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None
) -> UUID:
    """Check that the configured LLM model is usable and return its UUID.

    Args:
        config_id: Configuration ID, used in messages and error context.
        llm_id: LLM model ID as ``str``, ``UUID`` or ``None``.
        db: Database session used for the model lookup.
        tenant_id: Optional tenant ID forwarded to the lookup.
        workspace_id: Optional workspace ID for error context.

    Raises:
        InvalidConfigError: If llm_id is not provided or invalid
        ModelNotFoundError: If LLM model does not exist
        ModelInactiveError: If LLM model is inactive
    """
    def _not_configured() -> InvalidConfigError:
        # Single place that builds the "nothing configured" error.
        return InvalidConfigError(
            f"Configuration {config_id} has no LLM model configured",
            field_name="llm_model_id",
            invalid_value=llm_id,
            config_id=config_id,
            workspace_id=workspace_id
        )

    missing = llm_id is None or (isinstance(llm_id, str) and not llm_id.strip())
    if missing:
        raise _not_configured()

    resolved, _name = validate_and_resolve_model_id(
        llm_id, "llm", db, tenant_id, required=True,
        config_id=config_id, workspace_id=workspace_id
    )
    if resolved is None:
        raise _not_configured()
    return resolved

View File

@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Model - Backward Compatibility
This module provides backward compatibility for imports.
All classes have been moved to app.schemas.memory_config_schema.
DEPRECATED: Import from app.schemas.memory_config_schema instead.
"""
# Re-export for backward compatibility
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
MemoryConfigValidation,
ModelInactiveError,
ModelNotFoundError,
ModelValidation,
WorkspaceNotFoundError,
WorkspaceValidation,
validate_memory_config_data,
validate_model_data,
validate_workspace_data,
)
# Public names of this compatibility shim. Every entry is one of the names
# imported from app.schemas.memory_config_schema by this module.
__all__ = [
    "ConfigurationError",
    "InvalidConfigError",
    "MemoryConfig",
    "MemoryConfigValidation",
    "ModelInactiveError",
    "ModelNotFoundError",
    "ModelValidation",
    "WorkspaceNotFoundError",
    "WorkspaceValidation",
    "validate_memory_config_data",
    "validate_model_data",
    "validate_workspace_data",
]

View File

@@ -8,22 +8,25 @@ Classes:
DataConfigRepository: 数据配置仓储类提供CRUD操作
"""
from typing import Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import desc
import uuid
from typing import Dict, List, Optional, Tuple
from app.core.logging_config import get_config_logger, get_db_logger
from app.models.data_config_model import DataConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
)
from app.core.logging_config import get_db_logger
from sqlalchemy import desc
from sqlalchemy.orm import Session
# 获取数据库专用日志器
db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "data_config"
class DataConfigRepository:
@@ -525,7 +528,129 @@ class DataConfigRepository:
except Exception as e:
db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]:
    """Get a data config together with its associated workspace.

    Args:
        db: Database session
        config_id: Configuration ID

    Returns:
        Optional[tuple]: ``(DataConfig, Workspace)``, or ``None`` when no
        configuration with ``config_id`` exists.

    Raises:
        ValueError: The configuration exists but either has no
            ``workspace_id`` or references a workspace row that the join
            did not find.
        Exception: Any other error is logged and re-raised unchanged.
    """
    import time

    from app.models.workspace_model import Workspace

    # perf_counter() is monotonic; time.time() can jump with wall-clock
    # adjustments and would then report a wrong elapsed_ms.
    start_time = time.perf_counter()

    # Log configuration loading start
    config_logger.info(
        "Loading configuration with workspace",
        extra={
            "operation": "get_config_with_workspace",
            "config_id": config_id
        }
    )
    db_logger.debug(f"Querying data config and workspace: config_id={config_id}")

    try:
        # Inner join: a row comes back only when both the config and its
        # workspace exist.
        result = db.query(DataConfig, Workspace).join(
            Workspace, DataConfig.workspace_id == Workspace.id
        ).filter(DataConfig.config_id == config_id).first()

        elapsed_ms = (time.perf_counter() - start_time) * 1000

        if not result:
            # The join found nothing: tell "config missing" apart from
            # "config present but workspace missing".
            config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
            if config_only:
                if config_only.workspace_id is None:
                    config_logger.error(
                        "Configuration has no associated workspace ID",
                        extra={
                            "operation": "get_config_with_workspace",
                            "config_id": config_id,
                            "workspace_id": None,
                            "load_result": "no_workspace_id",
                            "elapsed_ms": elapsed_ms
                        }
                    )
                    db_logger.error(f"Data config {config_id} has no associated workspace ID")
                    raise ValueError(f"Configuration {config_id} has no associated workspace")

                config_logger.error(
                    "Configuration references non-existent workspace",
                    extra={
                        "operation": "get_config_with_workspace",
                        "config_id": config_id,
                        "workspace_id": str(config_only.workspace_id),
                        "load_result": "workspace_not_found",
                        "elapsed_ms": elapsed_ms
                    }
                )
                db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}")
                raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")

            config_logger.debug(
                "Configuration not found",
                extra={
                    "operation": "get_config_with_workspace",
                    "config_id": config_id,
                    "load_result": "not_found",
                    "elapsed_ms": elapsed_ms
                }
            )
            db_logger.debug(f"Data config not found: config_id={config_id}")
            return None

        config, workspace = result

        # Log successful configuration loading
        config_logger.info(
            "Configuration with workspace loaded successfully",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "config_name": config.config_name,
                "workspace_id": str(workspace.id),
                "workspace_name": workspace.name,
                "tenant_id": str(workspace.tenant_id),
                "load_result": "success",
                "elapsed_ms": elapsed_ms
            }
        )
        db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
        return (config, workspace)

    except ValueError:
        # Re-raise known business exceptions (already logged above)
        raise
    except Exception as e:
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        config_logger.error(
            "Failed to load configuration with workspace",
            extra={
                "operation": "get_config_with_workspace",
                "config_id": config_id,
                "load_result": "error",
                "error_type": type(e).__name__,
                "error_message": str(e),
                "elapsed_ms": elapsed_ms
            },
            exc_info=True
        )
        db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}")
        raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
"""获取所有配置参数

View File

@@ -0,0 +1,474 @@
# -*- coding: utf-8 -*-
"""Memory Configuration Schemas
This module provides schema definitions for memory configuration.
Classes:
MemoryConfig: Immutable memory configuration loaded from database
MemoryConfigValidation: Pydantic model for configuration validation
WorkspaceValidation: Pydantic model for workspace validation
ModelValidation: Pydantic model for model configuration validation
ConfigurationError: Base exception for configuration-related errors
WorkspaceNotFoundError: Raised when workspace does not exist
ModelNotFoundError: Raised when a required model does not exist
ModelInactiveError: Raised when a required model exists but is inactive
InvalidConfigError: Raised when configuration validation fails
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Literal, Optional, Union
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
# ==================== Configuration Exception Classes ====================
class ConfigurationError(Exception):
    """Base exception for configuration-related errors.

    Carries ``config_id``, ``workspace_id`` and a free-form ``context`` dict,
    and folds all of them into the exception message to help debugging.
    """

    def __init__(
        self,
        message: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize configuration error with context.

        Args:
            message: Error message describing the failure
            config_id: Optional configuration ID for context
            workspace_id: Optional workspace ID for context
            context: Optional additional context information
        """
        self.config_id = config_id
        self.workspace_id = workspace_id
        self.context = context or {}

        # Build detailed error message with context
        detailed_message = message
        if config_id is not None:
            detailed_message = f"Configuration {config_id}: {message}"
        if workspace_id is not None:
            detailed_message = f"{detailed_message} (workspace: {workspace_id})"

        # Add context information if available
        if self.context:
            context_str = ", ".join(f"{k}={v}" for k, v in self.context.items())
            detailed_message = f"{detailed_message} [Context: {context_str}]"

        super().__init__(detailed_message)

    def __reduce__(self):
        """Make the exception (and every subclass) safe to copy and pickle.

        The default ``BaseException.__reduce__`` rebuilds the object by calling
        ``cls(*self.args)``, i.e. with the already formatted message as the
        only argument. Subclasses whose ``__init__`` takes other required
        parameters then fail with ``TypeError`` (or rebuild a garbled
        message). Reconstructing through ``__new__`` and restoring ``args``
        plus the instance attributes avoids calling ``__init__`` again.
        """
        import copyreg

        state = dict(self.__dict__)
        # BaseException.__setstate__ applies each entry with setattr, so
        # "args" restores the formatted message as well.
        state["args"] = self.args
        return (copyreg.__newobj__, (type(self),), state)
class WorkspaceNotFoundError(ConfigurationError):
    """Signals that the referenced workspace does not exist."""

    def __init__(
        self,
        workspace_id: UUID,
        config_id: Optional[int] = None,
        message: Optional[str] = None,
    ):
        """Build the error; a default message is used when ``message`` is None."""
        text = (
            message
            if message is not None
            else f"Workspace {workspace_id} not found in database"
        )
        super().__init__(
            text,
            config_id=config_id,
            workspace_id=workspace_id,
            context={"workspace_id": str(workspace_id)},
        )
class ModelNotFoundError(ConfigurationError):
    """Signals that a required model does not exist."""

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        """Build the error; a default message is used when ``message`` is None."""
        text = (
            message
            if message is not None
            else f"{model_type.title()} model {model_id} not found in database"
        )
        details = dict(
            model_id=str(model_id),
            model_type=model_type,
            failure_type="not_found",
        )
        super().__init__(text, config_id=config_id, workspace_id=workspace_id, context=details)
class ModelInactiveError(ConfigurationError):
    """Signals that a required model exists but is inactive."""

    def __init__(
        self,
        model_id: Union[str, UUID],
        model_name: str,
        model_type: str,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
        message: Optional[str] = None,
    ):
        """Build the error; a default message is used when ``message`` is None."""
        text = (
            message
            if message is not None
            else f"{model_type.title()} model {model_id} ({model_name}) is inactive"
        )
        details = dict(
            model_id=str(model_id),
            model_name=model_name,
            model_type=model_type,
            failure_type="inactive",
        )
        super().__init__(text, config_id=config_id, workspace_id=workspace_id, context=details)
class InvalidConfigError(ConfigurationError):
    """Signals that configuration validation failed."""

    def __init__(
        self,
        message: str,
        field_name: Optional[str] = None,
        invalid_value: Optional[Any] = None,
        config_id: Optional[int] = None,
        workspace_id: Optional[UUID] = None,
    ):
        """Build the error, recording the offending field and value if given.

        ``field_name`` and ``invalid_value`` are only added to the context
        when they are not ``None``; the value is stored as ``str(value)``
        next to the name of its type.
        """
        details = {}
        if field_name is not None:
            details["field_name"] = field_name
        if invalid_value is not None:
            details.update(
                invalid_value=str(invalid_value),
                invalid_value_type=type(invalid_value).__name__,
            )
        super().__init__(message, config_id=config_id, workspace_id=workspace_id, context=details)
# ==================== Pydantic Validation Models ====================
class MemoryConfigValidation(BaseModel):
    """Pydantic schema used to validate memory configuration data from the database."""

    # Re-validate on attribute assignment and reject unknown fields.
    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    config_id: int = Field(..., gt=0, description="Configuration ID must be positive")
    config_name: str = Field(..., min_length=1, max_length=255)
    workspace_id: UUID = Field(..., description="Workspace UUID")
    workspace_name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")

    embedding_model_id: UUID = Field(..., description="Embedding model UUID (required)")
    embedding_model_name: str = Field(..., min_length=1, max_length=255)
    llm_model_id: UUID = Field(..., description="LLM model UUID (required)")
    llm_model_name: str = Field(..., min_length=1, max_length=255)
    rerank_model_id: Optional[UUID] = Field(None, description="Rerank model UUID (optional)")
    rerank_model_name: Optional[str] = Field(None, max_length=255)

    storage_type: str = Field(..., min_length=1, max_length=50)
    chunker_strategy: str = Field(default="RecursiveChunker", min_length=1, max_length=100)

    reflexion_enabled: bool = Field(default=False)
    reflexion_iteration_period: int = Field(default=3, ge=1, le=100)
    reflexion_range: Literal["retrieval", "all"] = Field(default="retrieval")
    reflexion_baseline: Literal["time", "fact", "time_and_fact"] = Field(default="time")

    llm_params: Dict[str, Any] = Field(default_factory=dict)
    embedding_params: Dict[str, Any] = Field(default_factory=dict)
    config_version: str = Field(default="2.0", min_length=1, max_length=10)

    @field_validator("config_name", "workspace_name", "embedding_model_name", "llm_model_name")
    @classmethod
    def validate_non_empty_strings(cls, v):
        """Reject empty or whitespace-only names; return the stripped value."""
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError("Field cannot be empty or whitespace-only")
        return stripped

    @field_validator("storage_type")
    @classmethod
    def validate_storage_type(cls, v):
        """Lower-case the storage type and require it to be a known backend."""
        valid_types = ["neo4j", "elasticsearch", "qdrant", "milvus", "chroma"]
        normalized = v.lower()
        if normalized not in valid_types:
            raise ValueError(f"Storage type must be one of: {valid_types}")
        return normalized

    @field_validator("llm_params", "embedding_params")
    @classmethod
    def validate_model_params(cls, v):
        """Require a dict that does not use any reserved parameter name."""
        if not isinstance(v, dict):
            raise ValueError("Model parameters must be a dictionary")

        reserved_keys = ["model_id", "model_name", "api_key", "base_url"]
        # Report the first reserved key in the dict's own iteration order.
        offending = [key for key in v if key in reserved_keys]
        if offending:
            raise ValueError(f"Model parameters cannot contain reserved parameter '{offending[0]}'")
        return v
class WorkspaceValidation(BaseModel):
    """Pydantic schema used to validate workspace data from the database."""

    # Re-validate on attribute assignment and reject unknown fields.
    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Workspace UUID")
    name: str = Field(..., min_length=1, max_length=255)
    tenant_id: UUID = Field(..., description="Tenant UUID")
    storage_type: Optional[str] = Field(None, max_length=50)
    llm: Optional[str] = Field(None)
    embedding: Optional[str] = Field(None)
    rerank: Optional[str] = Field(None)
    is_active: bool = Field(default=True)

    @field_validator("llm", "embedding", "rerank")
    @classmethod
    def validate_model_ids(cls, v):
        """Map None/"" to None; otherwise require a UUID string and strip it."""
        if v is None or v == "":
            return None
        stripped = v.strip()
        try:
            UUID(stripped)
        except ValueError:
            raise ValueError("Model ID must be a valid UUID string")
        return stripped

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Only active workspaces are accepted."""
        if v:
            return v
        raise ValueError("Workspace must be active for configuration loading")
class ModelValidation(BaseModel):
    """Pydantic schema used to validate model configuration data."""

    # Re-validate on attribute assignment and reject unknown fields.
    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    id: UUID = Field(..., description="Model UUID")
    name: str = Field(..., min_length=1, max_length=255)
    type: str = Field(..., description="Model type (llm, embedding, rerank)")
    tenant_id: UUID = Field(..., description="Tenant UUID")
    is_active: bool = Field(..., description="Whether model is active")
    is_public: bool = Field(default=False)

    @field_validator("type")
    @classmethod
    def validate_type(cls, v):
        """Lower-case the model type and require it to be a known kind."""
        valid_types = ["llm", "embedding", "rerank"]
        normalized = v.lower()
        if normalized not in valid_types:
            raise ValueError(f"Model type must be one of: {valid_types}")
        return normalized

    @field_validator("is_active")
    @classmethod
    def validate_active_status(cls, v):
        """Only active models are accepted."""
        if v:
            return v
        raise ValueError("Model must be active for configuration use")
# ==================== Validation Helper Functions ====================
def _invalid_config_from_validation_error(
    exc: ValidationError,
    heading: str,
    config_id: Optional[int] = None,
    workspace_id: Optional[Any] = None,
) -> InvalidConfigError:
    """Turn a Pydantic ``ValidationError`` into an ``InvalidConfigError``.

    The message is ``heading`` followed by one bullet per Pydantic error;
    the first error supplies ``field_name`` and ``invalid_value``.
    """
    errors = exc.errors()
    bullets = []
    for error in errors:
        field_path = " -> ".join(str(loc) for loc in error["loc"])
        bullets.append(f"  - Field '{field_path}': {error['msg']}")
    detailed_message = f"{heading}\n" + "\n".join(bullets)

    first_error = errors[0] if errors else {}
    first_field = " -> ".join(str(loc) for loc in first_error.get("loc", []))

    return InvalidConfigError(
        detailed_message,
        field_name=first_field or None,
        invalid_value=first_error.get("input"),
        config_id=config_id,
        workspace_id=workspace_id,
    )


def validate_memory_config_data(
    config_data: Dict[str, Any], config_id: Optional[int] = None
) -> MemoryConfigValidation:
    """Validate memory configuration data using the Pydantic model.

    Args:
        config_data: Field values passed as keyword arguments to
            ``MemoryConfigValidation``.
        config_id: Optional configuration ID attached to the raised error.

    Returns:
        The validated ``MemoryConfigValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; the original
            ``ValidationError`` is attached as ``__cause__``.
    """
    try:
        return MemoryConfigValidation(**config_data)
    except ValidationError as e:
        raise _invalid_config_from_validation_error(
            e, "Configuration validation failed:", config_id=config_id
        ) from e


def validate_workspace_data(
    workspace_data: Dict[str, Any], config_id: Optional[int] = None
) -> WorkspaceValidation:
    """Validate workspace data using the Pydantic model.

    Args:
        workspace_data: Field values passed as keyword arguments to
            ``WorkspaceValidation``.
        config_id: Optional configuration ID attached to the raised error.

    Returns:
        The validated ``WorkspaceValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; carries the
            workspace ``id`` from the input (when it is a dict) and the
            original ``ValidationError`` as ``__cause__``.
    """
    try:
        return WorkspaceValidation(**workspace_data)
    except ValidationError as e:
        workspace_id = workspace_data.get("id") if isinstance(workspace_data, dict) else None
        raise _invalid_config_from_validation_error(
            e, "Workspace validation failed:", config_id=config_id, workspace_id=workspace_id
        ) from e


def validate_model_data(
    model_data: Dict[str, Any], config_id: Optional[int] = None
) -> ModelValidation:
    """Validate model data using the Pydantic model.

    Args:
        model_data: Field values passed as keyword arguments to
            ``ModelValidation``.
        config_id: Optional configuration ID attached to the raised error.

    Returns:
        The validated ``ModelValidation`` instance.

    Raises:
        InvalidConfigError: If Pydantic validation fails; the original
            ``ValidationError`` is attached as ``__cause__``.
    """
    try:
        return ModelValidation(**model_data)
    except ValidationError as e:
        raise _invalid_config_from_validation_error(
            e, "Model validation failed:", config_id=config_id
        ) from e
# ==================== Immutable Configuration Data Structure ====================
@dataclass(frozen=True)
class MemoryConfig:
    """Frozen snapshot of a memory configuration.

    Instances are normally built through :meth:`from_validated_data`.
    ``__post_init__`` rejects an empty or whitespace-only configuration name
    and a falsy embedding or LLM model ID.
    """

    # --- Required identity / model / storage fields (no defaults) ---
    config_id: int
    config_name: str
    workspace_id: UUID
    workspace_name: str
    tenant_id: UUID
    embedding_model_id: UUID
    embedding_model_name: str
    llm_model_id: UUID
    llm_model_name: str
    storage_type: str
    chunker_strategy: str
    reflexion_enabled: bool
    reflexion_iteration_period: int
    reflexion_range: str
    reflexion_baseline: str
    loaded_at: datetime

    # --- Optional model settings ---
    rerank_model_id: Optional[UUID] = None
    rerank_model_name: Optional[str] = None
    llm_params: Dict[str, Any] = field(default_factory=dict)
    embedding_params: Dict[str, Any] = field(default_factory=dict)
    config_version: str = "2.0"

    # --- Pipeline: deduplication ---
    enable_llm_dedup_blockwise: bool = False
    enable_llm_disambiguation: bool = False
    deep_retrieval: bool = True
    t_type_strict: float = 0.8
    t_name_strict: float = 0.8
    t_overall: float = 0.8

    # --- Pipeline: statement extraction ---
    statement_granularity: int = 2
    include_dialogue_context: bool = False
    max_dialogue_context_chars: int = 1000

    # --- Pipeline: forgetting engine ---
    lambda_time: float = 0.5
    lambda_mem: float = 0.5
    offset: float = 0.0

    # --- Pipeline: pruning ---
    pruning_enabled: bool = False
    pruning_scene: Optional[str] = "education"
    pruning_threshold: float = 0.5

    def __post_init__(self):
        """Reject configurations missing a name or a mandatory model ID.

        Raises:
            InvalidConfigError: If ``config_name`` is empty or whitespace-only,
                or if ``embedding_model_id`` or ``llm_model_id`` is falsy.
        """
        name = self.config_name
        if not (name and name.strip()):
            raise InvalidConfigError("Configuration name cannot be empty")
        if not self.embedding_model_id:
            raise InvalidConfigError("Embedding model ID is required")
        if not self.llm_model_id:
            raise InvalidConfigError("LLM model ID is required")

    @classmethod
    def from_validated_data(
        cls, validated_config: MemoryConfigValidation, loaded_at: datetime
    ) -> "MemoryConfig":
        """Create a ``MemoryConfig`` from an already-validated config object.

        Args:
            validated_config: Object exposing the attributes copied below.
            loaded_at: Timestamp stored on the resulting instance.

        Returns:
            A new ``MemoryConfig``; pipeline fields keep their class defaults.
        """
        # Attributes taken verbatim from the validated object.
        copied_attrs = (
            "config_id",
            "config_name",
            "workspace_id",
            "workspace_name",
            "tenant_id",
            "embedding_model_id",
            "embedding_model_name",
            "llm_model_id",
            "llm_model_name",
            "rerank_model_id",
            "rerank_model_name",
            "storage_type",
            "chunker_strategy",
            "reflexion_enabled",
            "reflexion_iteration_period",
            "reflexion_range",
            "reflexion_baseline",
            "llm_params",
            "embedding_params",
            "config_version",
        )
        values = {attr: getattr(validated_config, attr) for attr in copied_attrs}
        return cls(loaded_at=loaded_at, **values)

    def get_model_summary(self) -> Dict[str, Optional[str]]:
        """Return the configured model names keyed by model role."""
        return dict(
            llm=self.llm_model_name,
            embedding=self.embedding_model_name,
            rerank=self.rerank_model_name,
        )

    def is_model_configured(self, model_type: str) -> bool:
        """Report whether a model of the given type is configured.

        ``"llm"`` and ``"embedding"`` always report ``True``; ``"rerank"``
        reports whether ``rerank_model_id`` is set.

        Raises:
            ValueError: If ``model_type`` is not one of the three known types.
        """
        if model_type == "llm" or model_type == "embedding":
            return True
        if model_type == "rerank":
            return self.rerank_model_id is not None
        raise ValueError(f"Unknown model type: {model_type}")

View File

@@ -3,26 +3,26 @@
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import time
import uuid
import json
import asyncio
import datetime
from typing import Dict, Any, Optional, List, AsyncGenerator
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy import select
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.models import AgentConfig, ModelConfig, ModelApiKey
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.core.rag.nlp.search import knowledge_retrieval
from app.services.langchain_tool_server import Search
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
from app.db import get_db
db = next(get_db())
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
)
)
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -713,9 +719,9 @@ class DraftRunService:
Raises:
BusinessException: 当指定的会话不存在时
"""
from app.services.conversation_service import ConversationService
from app.schemas.conversation_schema import ConversationCreate
from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)

View File

@@ -7,14 +7,15 @@ Classes:
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
"""
from typing import Dict, Any, Optional, List
import statistics
import json
from pydantic import BaseModel, Field
import statistics
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.repositories.neo4j.emotion_repository import EmotionRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_business_logger()
@@ -454,7 +455,7 @@ class EmotionAnalyticsService:
async def generate_emotion_suggestions(
self,
end_user_id: str,
config_id: Optional[int] = None
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
@@ -462,7 +463,7 @@ class EmotionAnalyticsService:
Args:
end_user_id: 宿主ID用户组ID
config_id: 配置ID可选用于从数据库加载LLM配置
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
@@ -470,14 +471,32 @@ class EmotionAnalyticsService:
- suggestions: 建议列表3-5条
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 如果提供了 config_id从数据库加载配置
if config_id is not None:
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
@@ -498,8 +517,9 @@ class EmotionAnalyticsService:
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
from app.core.memory.utils.llm.llm_utils import get_llm_client
llm_client = get_llm_client()
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
@@ -598,7 +618,9 @@ class EmotionAnalyticsService:
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,

View File

@@ -9,10 +9,12 @@ Classes:
import logging
from typing import Optional
from app.core.memory.models.emotion_models import EmotionExtraction
from app.models.data_config_model import DataConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.data_config_model import DataConfig
logger = logging.getLogger(__name__)
@@ -50,7 +52,9 @@ class EmotionExtractionService:
"""
if self.llm_client is None or model_id:
effective_model_id = model_id or self.llm_id
self.llm_client = get_llm_client(effective_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(effective_model_id)
return self.llm_client
async def extract_emotion(
@@ -142,7 +146,9 @@ class EmotionExtractionService:
Returns:
Formatted prompt string for LLM
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_extraction_prompt,
)
prompt = await render_emotion_extraction_prompt(
statement=statement,

View File

@@ -4,50 +4,48 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import json
import os
import re
import time
import json
import uuid
from threading import Lock
from typing import Dict, List, Optional, Any, AsyncGenerator
from app.services.memory_konwledges_server import write_rag
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from sqlalchemy.orm import Session
from sqlalchemy import func
from pydantic import BaseModel, Field
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.db import get_db
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.schemas.file_schema import CustomTextFileCreate
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__)
config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService:
"""Service for memory agent operations"""
@@ -257,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Write operation result status
@@ -272,24 +273,40 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
import time
start_time = time.time()
# 如果 config_id 为 None使用默认值 "17"
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
@@ -300,20 +317,43 @@ class MemoryAgentService:
async with client.session("data_flow") as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
workflow_errors = [] # Track errors from workflow
# Pass config_id to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
# Pass memory_config to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
logger.debug("Write graph created successfully")
config = {"configurable": {"thread_id": group_id}}
async for event in graph.astream(
{"messages": message, "config_id": config_id},
{"messages": message, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.error(f"Write workflow failed with errors: {error_details}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_details
)
raise ValueError(f"Write workflow failed: {error_details}")
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
async def read_memory(
self,
@@ -321,7 +361,8 @@ class MemoryAgentService:
message: str,
history: List[Dict],
search_switch: str,
config_id: str,
config_id: Optional[str],
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
@@ -334,11 +375,14 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Dict with 'answer' and 'intermediate_outputs' keys
@@ -350,8 +394,18 @@ class MemoryAgentService:
import time
start_time = time.time()
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -365,15 +419,19 @@ class MemoryAgentService:
group_lock = self.get_group_lock(group_id)
with group_lock:
# Step 1: Load configuration from database
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Step 1: Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -387,8 +445,6 @@ class MemoryAgentService:
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -404,45 +460,52 @@ class MemoryAgentService:
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass config_id to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
# Pass memory_config to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
async for event in graph.astream(
{"messages": history, "config_id": config_id},
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg.__class__.__name__.lower().replace("message", ""),
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Debug: log message type and content preview
msg_type = msg.__class__.__name__
content_preview = str(msg_content)[:200] if msg_content else "empty"
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(msg_content, str):
if isinstance(content_to_parse, str):
try:
parsed = json.loads(msg_content)
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Debug: log what keys are in parsed
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
@@ -450,34 +513,14 @@ class MemoryAgentService:
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
elif isinstance(msg_content, dict):
# Check for single intermediate output
if '_intermediate' in msg_content:
intermediate_data = msg_content['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in msg_content:
for intermediate_data in msg_content['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
@@ -489,18 +532,57 @@ class MemoryAgentService:
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try:
message = json.loads(message) if isinstance(message, str) else message
if isinstance(message, dict) and message.get('status') != '':
summary_result = message.get('summary_result')
if summary_result:
final_answer = summary_result
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
if audit_logger:
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
@@ -612,19 +694,29 @@ class MemoryAgentService:
else:
return output
async def classify_message_type(self, message: str) -> Dict:
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
Args:
message: User message to classify
config_id: Configuration ID to load LLM model from database
db: Database session
Returns:
Type classification result
"""
logger.info("Classifying message type")
status = await status_typle(message)
# Load configuration to get LLM model ID
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
@@ -790,7 +882,9 @@ class MemoryAgentService:
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -801,6 +895,8 @@ class MemoryAgentService:
参数:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式:
{
@@ -817,7 +913,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name
try:
if end_user_id:
if end_user_id and db:
from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -862,15 +958,19 @@ class MemoryAgentService:
await connector.close()
if not statements:
if not statements or not llm_id:
result["tags"] = []
if not llm_id and statements:
logger.warning("llm_id not provided, skipping tag generation")
else:
# 构建摘要文本
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client()
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):
@@ -1032,4 +1132,69 @@ class MemoryAgentService:
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
    """Look up the memory configuration connected to an end user.

    Resolution steps:
        1. Load the end user by ID and read its ``app_id``.
        2. Pick the active release of that app with the highest version.
        3. Read ``memory.memory_content`` from the release ``config``.

    Args:
        end_user_id: End user ID.
        db: Database session.

    Returns:
        Dict with ``end_user_id``, ``app_id``, ``release_id``,
        ``release_version`` and ``memory_config_id``. ``memory_config_id`` is
        ``None`` when the release config has no such entry.

    Raises:
        ValueError: If the end user does not exist or the app has no active
            release.
    """
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from sqlalchemy import select

    logger.info(f"Getting connected config for end_user: {end_user_id}")

    # Step 1: resolve the end user and the app it belongs to.
    user_row = db.query(EndUser).filter(EndUser.id == end_user_id).first()
    if not user_row:
        logger.warning(f"End user not found: {end_user_id}")
        raise ValueError(f"终端用户不存在: {end_user_id}")

    app_id = user_row.app_id
    logger.debug(f"Found end_user app_id: {app_id}")

    # Step 2: newest active release of that app (highest version first).
    release_query = select(AppRelease).where(
        AppRelease.app_id == app_id, AppRelease.is_active.is_(True)
    )
    release_query = release_query.order_by(AppRelease.version.desc())
    latest_release = db.scalars(release_query).first()

    if not latest_release:
        logger.warning(f"No active release found for app: {app_id}")
        raise ValueError(f"应用未发布: {app_id}")

    logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")

    # Step 3: pull the memory config ID out of the release config.
    release_config = latest_release.config or {}
    memory_section = release_config.get('memory', {})
    if isinstance(memory_section, dict):
        memory_config_id = memory_section.get('memory_content')
    else:
        memory_config_id = None

    connected = {
        "end_user_id": str(end_user_id),
        "app_id": str(app_id),
        "release_id": str(latest_release.id),
        "release_version": latest_release.version,
        "memory_config_id": memory_config_id
    }

    logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
    return connected

View File

@@ -0,0 +1,399 @@
"""
Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
"""
import time
from datetime import datetime
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_model_exists_and_active,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
def _validate_config_id(config_id):
    """Normalize a configuration ID to a positive integer.

    Args:
        config_id: An ``int``, or a ``str`` holding a base-10 integer
            (surrounding whitespace is ignored).

    Returns:
        The configuration ID as an integer greater than zero.

    Raises:
        InvalidConfigError: If ``config_id`` is ``None``, is neither ``int``
            nor ``str``, cannot be parsed as an integer, or is not positive.
    """
    if config_id is None:
        raise InvalidConfigError(
            "Configuration ID cannot be None",
            field_name="config_id",
            invalid_value=config_id,
        )

    if isinstance(config_id, int):
        parsed_id = config_id
    elif isinstance(config_id, str):
        # Only the conversion is guarded. The positivity check below must not
        # sit inside this try: its InvalidConfigError could otherwise be
        # intercepted by ``except ValueError`` (if the error type derives from
        # ValueError) and be misreported as a format problem.
        try:
            parsed_id = int(config_id.strip())
        except ValueError as exc:
            raise InvalidConfigError(
                f"Invalid configuration ID format: '{config_id}'",
                field_name="config_id",
                invalid_value=config_id,
            ) from exc
    else:
        raise InvalidConfigError(
            f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
            field_name="config_id",
            invalid_value=config_id,
        )

    # Shared range check; ``invalid_value`` reports the caller's original input.
    if parsed_id <= 0:
        raise InvalidConfigError(
            f"Configuration ID must be positive: {parsed_id}",
            field_name="config_id",
            invalid_value=config_id,
        )
    return parsed_id
class MemoryConfigService:
    """Centralized service for loading and validating memory configuration.

    Provides a single implementation of the configuration-loading logic so
    that several services can share it instead of each carrying a copy.

    Usage:
        config_service = MemoryConfigService(db)
        memory_config = config_service.load_memory_config(config_id)
        model_config = config_service.get_model_config(model_id)
    """

    def __init__(self, db: Session):
        """Store the database session used by every lookup of this service.

        Args:
            db: SQLAlchemy database session.
        """
        self.db = db

    @staticmethod
    def _cast_or_default(value, cast, default):
        """Return ``cast(value)``, or ``default`` when ``value`` is ``None``."""
        return default if value is None else cast(value)

    def load_memory_config(
        self,
        config_id: int,
        service_name: str = "MemoryConfigService",
    ) -> MemoryConfig:
        """Load a memory configuration from the database by ``config_id``.

        The ID is validated, the configuration row and its workspace are
        fetched, the embedding / LLM / optional rerank models are validated
        and resolved, and the result is assembled into a ``MemoryConfig``.
        Unset (``None``) pipeline columns fall back to the defaults written
        out below.

        Args:
            config_id: Configuration ID from the database.
            service_name: Name of the calling service (used in log records).

        Returns:
            MemoryConfig: The assembled configuration object.

        Raises:
            ConfigurationError: If the configuration is missing or loading
                fails. ``ConfigurationError`` and ``ValueError`` raised while
                loading are re-raised unchanged; any other exception is
                wrapped in a ``ConfigurationError`` with the original as its
                cause.
        """
        start_time = time.time()
        config_logger.info(
            "Starting memory configuration loading",
            extra={
                "operation": "load_memory_config",
                "service": service_name,
                "config_id": config_id,
            },
        )
        logger.info(f"Loading memory configuration from database: config_id={config_id}")

        try:
            validated_config_id = _validate_config_id(config_id)

            result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
            if not result:
                elapsed_ms = (time.time() - start_time) * 1000
                config_logger.error(
                    "Configuration not found in database",
                    extra={
                        "operation": "load_memory_config",
                        "config_id": validated_config_id,
                        "load_result": "not_found",
                        "elapsed_ms": elapsed_ms,
                        "service": service_name,
                    },
                )
                raise ConfigurationError(
                    f"Configuration {validated_config_id} not found in database"
                )

            memory_config, workspace = result

            # Validate embedding model
            embedding_uuid = validate_embedding_model(
                validated_config_id,
                memory_config.embedding_id,
                self.db,
                workspace.tenant_id,
                workspace.id,
            )

            # Resolve LLM model
            llm_uuid, llm_name = validate_and_resolve_model_id(
                memory_config.llm_id,
                "llm",
                self.db,
                workspace.tenant_id,
                required=True,
                config_id=validated_config_id,
                workspace_id=workspace.id,
            )

            # Resolve optional rerank model
            rerank_uuid = None
            rerank_name = None
            if memory_config.rerank_id:
                rerank_uuid, rerank_name = validate_and_resolve_model_id(
                    memory_config.rerank_id,
                    "rerank",
                    self.db,
                    workspace.tenant_id,
                    required=False,
                    config_id=validated_config_id,
                    workspace_id=workspace.id,
                )

            # Get embedding model name
            embedding_name, _ = validate_model_exists_and_active(
                embedding_uuid,
                "embedding",
                self.db,
                workspace.tenant_id,
                config_id=validated_config_id,
                workspace_id=workspace.id,
            )

            pick = self._cast_or_default

            # Assemble the MemoryConfig object
            config = MemoryConfig(
                config_id=memory_config.config_id,
                config_name=memory_config.config_name,
                workspace_id=workspace.id,
                workspace_name=workspace.name,
                tenant_id=workspace.tenant_id,
                llm_model_id=llm_uuid,
                llm_model_name=llm_name,
                embedding_model_id=embedding_uuid,
                embedding_model_name=embedding_name,
                rerank_model_id=rerank_uuid,
                rerank_model_name=rerank_name,
                storage_type=workspace.storage_type or "neo4j",
                chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
                reflexion_enabled=memory_config.enable_self_reflexion or False,
                reflexion_iteration_period=int(memory_config.iteration_period or "3"),
                reflexion_range=memory_config.reflexion_range or "retrieval",
                reflexion_baseline=memory_config.baseline or "time",
                loaded_at=datetime.now(),
                # Pipeline config: Deduplication
                enable_llm_dedup_blockwise=pick(memory_config.enable_llm_dedup_blockwise, bool, False),
                enable_llm_disambiguation=pick(memory_config.enable_llm_disambiguation, bool, False),
                deep_retrieval=pick(memory_config.deep_retrieval, bool, True),
                t_type_strict=pick(memory_config.t_type_strict, float, 0.8),
                t_name_strict=pick(memory_config.t_name_strict, float, 0.8),
                t_overall=pick(memory_config.t_overall, float, 0.8),
                # Pipeline config: Statement extraction
                statement_granularity=pick(memory_config.statement_granularity, int, 2),
                include_dialogue_context=pick(memory_config.include_dialogue_context, bool, False),
                max_dialogue_context_chars=pick(memory_config.max_context, int, 1000),
                # Pipeline config: Forgetting engine
                lambda_time=pick(memory_config.lambda_time, float, 0.5),
                lambda_mem=pick(memory_config.lambda_mem, float, 0.5),
                offset=pick(memory_config.offset, float, 0.0),
                # Pipeline config: Pruning
                pruning_enabled=pick(memory_config.pruning_enabled, bool, False),
                pruning_scene=memory_config.pruning_scene or "education",
                pruning_threshold=pick(memory_config.pruning_threshold, float, 0.5),
            )

            elapsed_ms = (time.time() - start_time) * 1000
            config_logger.info(
                "Memory configuration loaded successfully",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    "config_id": validated_config_id,
                    "config_name": config.config_name,
                    "workspace_id": str(config.workspace_id),
                    "load_result": "success",
                    "elapsed_ms": elapsed_ms,
                },
            )
            logger.info(f"Memory configuration loaded successfully: {config.config_name}")
            return config

        except Exception as e:
            elapsed_ms = (time.time() - start_time) * 1000
            config_logger.error(
                "Failed to load memory configuration",
                extra={
                    "operation": "load_memory_config",
                    "service": service_name,
                    "config_id": config_id,
                    "load_result": "error",
                    "error_type": type(e).__name__,
                    "error_message": str(e),
                    "elapsed_ms": elapsed_ms,
                },
                exc_info=True,
            )
            logger.error(f"Failed to load memory configuration {config_id}: {e}")

            if isinstance(e, (ConfigurationError, ValueError)):
                raise
            # Chain explicitly so the root cause is kept as __cause__.
            raise ConfigurationError(
                f"Failed to load configuration {config_id}: {e}"
            ) from e

    def _build_client_config(
        self,
        model_id: str,
        *,
        log_label: str,
        not_found_detail: str,
        missing_key_detail: str,
        timeout,
        max_retries,
    ) -> dict:
        """Look up a model and build the client settings dict from its first API key.

        Shared by ``get_model_config`` and ``get_embedder_config`` so the
        lookup and error handling exist in one place.

        Raises:
            HTTPException: 404 when the model is not found, or when it has no
                API key entry (previously an uncaught ``IndexError``).
        """
        from app.models.models_model import ModelApiKey
        from app.services.model_service import ModelConfigService as ModelSvc
        from fastapi import status
        from fastapi.exceptions import HTTPException

        config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
        if not config:
            logger.warning(f"{log_label} {model_id} not found")
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=not_found_detail)

        # Guard the [0] access: a model without any API key would otherwise
        # surface as a bare IndexError.
        if not config.api_keys:
            logger.warning(f"{log_label} {model_id} has no API key configured")
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=missing_key_detail)

        api_config: ModelApiKey = config.api_keys[0]
        return {
            "model_name": api_config.model_name,
            "provider": api_config.provider,
            "api_key": api_config.api_key,
            "base_url": api_config.api_base,
            "model_config_id": api_config.model_config_id,
            "type": config.type,
            "timeout": timeout,
            "max_retries": max_retries,
        }

    def get_model_config(self, model_id: str) -> dict:
        """Get LLM model configuration by ID.

        Args:
            model_id: Model ID to look up.

        Returns:
            Dict with ``model_name``, ``provider``, ``api_key``, ``base_url``,
            ``model_config_id``, ``type``, ``timeout`` and ``max_retries``;
            the last two come from ``settings.LLM_TIMEOUT`` and
            ``settings.LLM_MAX_RETRIES``.

        Raises:
            HTTPException: 404 when the model does not exist or has no API key.
        """
        from app.core.config import settings

        return self._build_client_config(
            model_id,
            log_label="Model ID",
            not_found_detail="模型ID不存在",
            missing_key_detail="模型未配置API Key",
            timeout=settings.LLM_TIMEOUT,
            max_retries=settings.LLM_MAX_RETRIES,
        )

    def get_embedder_config(self, embedding_id: str) -> dict:
        """Get embedding model configuration by ID.

        Args:
            embedding_id: Embedding model ID to look up.

        Returns:
            Dict with the same keys as ``get_model_config``; ``timeout`` is
            fixed at ``120.0`` and ``max_retries`` at ``5``.

        Raises:
            HTTPException: 404 when the model does not exist or has no API key.
        """
        return self._build_client_config(
            embedding_id,
            log_label="Embedding model ID",
            not_found_detail="嵌入模型ID不存在",
            missing_key_detail="嵌入模型未配置API Key",
            timeout=120.0,
            max_retries=5,
        )

    @staticmethod
    def get_pipeline_config(memory_config: MemoryConfig):
        """Build an ExtractionPipelineConfig from a MemoryConfig.

        Args:
            memory_config: MemoryConfig object containing all pipeline settings.

        Returns:
            ExtractionPipelineConfig with deduplication, statement extraction,
            and forgetting engine settings.
        """
        from app.core.memory.models.variate_config import (
            DedupConfig,
            ExtractionPipelineConfig,
            ForgettingEngineConfig,
            StatementExtractionConfig,
        )

        dedup_config = DedupConfig(
            enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
            enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
            fuzzy_name_threshold_strict=memory_config.t_name_strict,
            fuzzy_type_threshold_strict=memory_config.t_type_strict,
            fuzzy_overall_threshold=memory_config.t_overall,
        )

        stmt_config = StatementExtractionConfig(
            statement_granularity=memory_config.statement_granularity,
            include_dialogue_context=memory_config.include_dialogue_context,
            max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
        )

        forget_config = ForgettingEngineConfig(
            offset=memory_config.offset,
            lambda_time=memory_config.lambda_time,
            lambda_mem=memory_config.lambda_mem,
        )

        return ExtractionPipelineConfig(
            statement_extraction=stmt_config,
            deduplication=dedup_config,
            forgetting_engine=forget_config,
        )

    @staticmethod
    def get_pruning_config(memory_config: MemoryConfig) -> dict:
        """Extract the semantic pruning settings from a MemoryConfig.

        Args:
            memory_config: MemoryConfig object containing pruning settings.

        Returns:
            Dict with the keys:
                - pruning_switch: value of ``memory_config.pruning_enabled``
                - pruning_scene: value of ``memory_config.pruning_scene``
                - pruning_threshold: value of ``memory_config.pruning_threshold``
        """
        return {
            "pruning_switch": memory_config.pruning_enabled,
            "pruning_scene": memory_config.pruning_scene,
            "pruning_threshold": memory_config.pruning_threshold,
        }

View File

@@ -4,38 +4,37 @@ Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any, AsyncGenerator
import os
import json
import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from sqlalchemy.orm import Session
from dotenv import load_dotenv
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.models.user_model import User
from app.core.logging_config import get_logger
from app.utils.sse_utils import format_sse_message
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError
from app.schemas.memory_storage_schema import (
ConfigPilotRun,
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.end_user_repository import EndUserRepository
import uuid
from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
# Load environment variables for Neo4j connector
load_dotenv()
@@ -247,7 +246,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
RuntimeError: 当管线执行失败时
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
try:
# 发出初始进度事件
@@ -256,24 +254,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000)
})
# 步骤 1: 配置加载和验证(复用现有逻辑
# 步骤 1: 配置加载和验证(数据库优先
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
raise ValueError("未提供 payload.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
@@ -281,12 +267,16 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 应用内存覆写并刷新常量
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# Load configuration from database only using centralized manager
try:
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(
config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
raise RuntimeError(f"Configuration loading failed: {e}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件
@@ -307,13 +297,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.core.memory.main import main as pipeline_main
from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(
dialogue_text=dialogue_text,
is_pilot_run=True,
progress_callback=progress_callback
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
db=self.db,
progress_callback=progress_callback,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

View File

@@ -0,0 +1,219 @@
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
# Module-level logger for the pilot-run service, obtained from the project's
# memory logger factory.
logger = get_memory_logger(__name__)
async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
    """Run the knowledge-extraction pipeline in pilot (trial-run) mode.

    The dialogue text is parsed into messages, chunked, and passed through
    the extraction orchestrator with ``is_pilot_run=True``; a memory-summary
    step follows. Step timings are appended to ``logs/time.log``. This
    function does not itself write the extracted results to Neo4j.

    Args:
        memory_config: Memory configuration loaded from the database.
        dialogue_text: Raw dialogue text. Lines starting with "用户:" or
            "AI:" (ASCII or full-width colon) are split into separate
            messages; if no such label is found the whole text becomes a
            single "用户" message.
        db: Database session handed to ``MemoryClientFactory``.
        progress_callback: Optional async progress callback, called as
            ``(stage, message, data)``:
            - stage: identifier of the current processing stage
            - message: human-readable progress message
            - data: optional dict with additional data (``None`` if absent)

    Raises:
        Exception: Any failure outside the memory-summary step is logged and
            re-raised. Failures of the memory-summary step are logged only.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()
    neo4j_connector = None

    try:
        # Step 1: initialize clients
        logger.info("Initializing clients...")
        step_start = time.time()

        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
        neo4j_connector = Neo4jConnector()

        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text
        logger.info("Parsing dialogue text...")
        step_start = time.time()

        # Speaker labels are "用户" / "AI" followed by an ASCII or full-width
        # colon. The continuation group is greedy on purpose: as the last
        # element of the pattern, a lazy quantifier would always match zero
        # repetitions and drop every line after the first one of a turn.
        pattern = r"(用户|AI)[:：]\s*([^\n]+(?:\n(?!(?:用户|AI)[:：])[^\n]*)*)"
        matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
        messages = [
            ConversationMessage(role=r, msg=c.strip())
            for r, c in matches
            if c.strip()
        ]

        # No labelled turns found: treat the whole text as one user message.
        if not messages:
            messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]

        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            group_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            # Pass data=None explicitly so the call matches the declared
            # three-argument callback type.
            await progress_callback("text_preprocessing", "开始预处理文本...", None)

        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=[dialog],
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        logger.info(f"Processed dialogue text: {len(messages)} messages")

        # Progress: report each chunk, then a summary of the preprocessing.
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks):
                    # Content is previewed at 200 characters; full_length
                    # carries the untruncated size.
                    preview = (
                        chunk.content[:200] + "..."
                        if len(chunk.content) > 200
                        else chunk.content
                    )
                    chunk_result = {
                        "chunk_index": i + 1,
                        "content": preview,
                        "full_length": len(chunk.content),
                        "dialog_id": dlg.id,
                        "chunker_strategy": memory_config.chunker_strategy,
                    }
                    await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialize the pipeline orchestrator
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()

        config = get_pipeline_config(memory_config)
        logger.info(
            f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
            f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
        )

        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=str(memory_config.embedding_model_id),
        )

        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the knowledge-extraction pipeline
        logger.info("Running extraction pipeline...")
        step_start = time.time()

        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...", None)

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=True,
        )

        # The unpacked values are not used further in this function; the
        # unpacking still fails fast if the result is not a 7-item sequence.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result

        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...", None)

        # Step 5: generate memory summaries. A failure here is logged but
        # does not fail the pilot run.
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()

            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                memory_summary_generation,
            )

            await memory_summary_generation(
                chunked_dialogs,
                llm_client=llm_client,
                embedder_client=embedder_client,
            )

            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

        logger.info("Pilot run completed: Skipping Neo4j save")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise
    finally:
        if neo4j_connector:
            try:
                await neo4j_connector.close()
            except Exception as close_err:
                # Best-effort cleanup: do not mask the pipeline outcome, but
                # leave a trace instead of swallowing the error silently.
                logger.warning(f"Failed to close Neo4j connector: {close_err}")

    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")

    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")

View File

@@ -4,15 +4,15 @@ User Memory Service
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
"""
from typing import Dict, List, Optional, Any
import uuid
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_logger
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from sqlalchemy.orm import Session
logger = get_logger(__name__)
@@ -284,8 +284,7 @@ class UserMemoryService:
# 使用 end_user_id 调用分析函数
try:
logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
result = await analytics_user_summary(end_user_id)
summary = result.get("summary", "")
summary = await generate_user_summary(end_user_id)
if not summary:
logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")

View File

@@ -1,26 +1,30 @@
import asyncio
from typing import Any, Dict, List
import requests
from datetime import datetime, timezone
import json
import os
import time
import uuid
from datetime import datetime, timezone
from math import ceil
import redis
from typing import Any, Dict, List, Optional
from app.db import get_db_context
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.chat_model import Base
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.models.chunk import DocumentChunk
from app.services.memory_agent_service import MemoryAgentService
from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.prompts.generator import question_proposal
import redis
import requests
# Import a unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db, get_db_context
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.services.memory_agent_service import MemoryAgentService
@celery_app.task(name="tasks.process_item")
@@ -170,7 +174,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""Celery task to process a read message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent
group_id: Group ID for the memory agent (also used as end_user_id)
message: User message to process
history: Conversation history
search_switch: Search switch parameter
@@ -184,9 +188,28 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""
start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, config_id,storage_type,user_rag_memory_id)
db = next(get_db())
try:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -217,11 +240,17 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
else:
detailed_error = str(e)
return {
"status": "FAILURE",
"error": str(e),
"error": detailed_error,
"group_id": group_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
@@ -234,7 +263,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"""Celery task to process a write message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent
group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write
config_id: Optional configuration ID
@@ -246,9 +275,28 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"""
start_time = time.time()
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception as e:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
service = MemoryAgentService()
return await service.write_memory(group_id, message, config_id,storage_type,user_rag_memory_id)
db = next(get_db())
try:
service = MemoryAgentService()
return await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -279,11 +327,17 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
else:
detailed_error = str(e)
return {
"status": "FAILURE",
"error": str(e),
"error": detailed_error,
"group_id": group_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
@@ -291,6 +345,27 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
}
def reflection_engine(host_id: uuid.UUID | None = None) -> None:
    """Run one self-reflexion pass for a host.

    Imports ``self_reflexion`` and drives it to completion with
    ``asyncio.run``; any exception it raises propagates to the caller.

    Args:
        host_id: Host to run ``self_reflexion`` for. When omitted, the
            previously hard-coded host ID is used, so existing zero-argument
            callers behave exactly as before.
    """
    import asyncio

    from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion

    if host_id is None:
        host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
    asyncio.run(self_reflexion(host_id))
@celery_app.task(name="app.core.memory.agent.reflection.timer")
def reflection_timer_task() -> None:
    """Celery task wrapper around ``reflection_engine``.

    Nothing is caught here: an exception raised by ``reflection_engine``
    propagates out of the task.
    """
    reflection_engine()
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]:
@@ -353,10 +428,10 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.services.memory_storage_service import search_all
from app.repositories.memory_increment_repository import write_memory_increment
from app.models.end_user_model import EndUser
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
with get_db_context() as db:
try:
@@ -465,9 +540,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.services.user_memory_service import UserMemoryService
from app.repositories.end_user_repository import EndUserRepository
from app.core.logging_config import get_logger
from app.repositories.end_user_repository import EndUserRepository
from app.services.user_memory_service import UserMemoryService
logger = get_logger(__name__)
logger.info("开始执行记忆缓存重新生成定时任务")
@@ -645,9 +720,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
from app.models.workspace_model import Workspace
from app.core.logging_config import get_api_logger
from app.models.workspace_model import Workspace
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
api_logger = get_api_logger()

View File

@@ -127,6 +127,7 @@ dependencies = [
"uvicorn>=0.34.0",
"celery>=5.5.2",
"simpleeval>=1.0.3",
"langchain-aws>=1.0.0a1",
]
[tool.pytest.ini_options]

71
api/uv.lock generated
View File

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = "==3.12.*"
resolution-markers = [
"sys_platform == 'darwin'",
@@ -7,6 +7,9 @@ resolution-markers = [
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
]
[options]
prerelease-mode = "allow"
[[package]]
name = "aiohappyeyeballs"
version = "2.6.1"
@@ -241,6 +244,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
]
[[package]]
name = "boto3"
version = "1.42.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore" },
{ name = "jmespath" },
{ name = "s3transfer" },
]
sdist = { url = "https://files.pythonhosted.org/packages/09/72/e236ca627bc0461710685f5b7438f759ef3b4106e0e08dda08513a6539ab/boto3-1.42.14.tar.gz", hash = "sha256:a5d005667b480c844ed3f814a59f199ce249d0f5669532a17d06200c0a93119c", size = 112825, upload-time = "2025-12-19T20:27:15.325Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bb/ba/c657ea6f6d63563cc46748202fccd097b51755d17add00ebe4ea27580d06/boto3-1.42.14-py3-none-any.whl", hash = "sha256:bfcc665227bb4432a235cb4adb47719438d6472e5ccbf7f09512046c3f749670", size = 140571, upload-time = "2025-12-19T20:27:13.316Z" },
]
[[package]]
name = "botocore"
version = "1.42.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jmespath" },
{ name = "python-dateutil" },
{ name = "urllib3" },
]
sdist = { url = "https://files.pythonhosted.org/packages/35/3f/50c56f093c2c6ce6de1f579726598db1cf9a9cccd3bf8693f73b1cf5e319/botocore-1.42.14.tar.gz", hash = "sha256:cf5bebb580803c6cfd9886902ca24834b42ecaa808da14fb8cd35ad523c9f621", size = 14910547, upload-time = "2025-12-19T20:27:04.431Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ad/94/67a78a8d08359e779894d4b1672658a3c7fcce216b48f06dfbe1de45521d/botocore-1.42.14-py3-none-any.whl", hash = "sha256:efe89adfafa00101390ec2c371d453b3359d5f9690261bc3bd70131e0d453e8e", size = 14583247, upload-time = "2025-12-19T20:27:00.54Z" },
]
[[package]]
name = "cachetools"
version = "6.2.1"
@@ -1114,6 +1145,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2f/9c/6753e6522b8d0ef07d3a3d239426669e984fb0eba15a315cdbc1253904e4/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24e864cb30ab82311c6425655b0cdab0a98c5d973b065c66a3f020740c2324c", size = 346110, upload-time = "2025-11-09T20:49:21.817Z" },
]
[[package]]
name = "jmespath"
version = "1.0.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
]
[[package]]
name = "joblib"
version = "1.5.2"
@@ -1262,6 +1302,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/39/ed3121ea3a0c60a0cda6ea5c4c1cece013e8bbc9b18344ff3ae507728f98/langchain-1.1.3-py3-none-any.whl", hash = "sha256:e5b208ed93e553df4087117a40bd0d450f9095030a843cad35c53ff2814bf731", size = 102227, upload-time = "2025-12-08T19:31:47.246Z" },
]
[[package]]
name = "langchain-aws"
version = "1.0.0a1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "boto3" },
{ name = "langchain-core" },
{ name = "numpy" },
{ name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c0/c3/a98c0849c13c6880b5629409cadb22d4070e9c611013da127be975f8c0dc/langchain_aws-1.0.0a1.tar.gz", hash = "sha256:3bb193a5fa915520c52bb47581e892d11ac4d114939a1b3ecfeca56fe153fff7", size = 121650, upload-time = "2025-09-18T20:52:36.098Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/9d/7b/be49a224fe3aa07ed869801356f06e1d7a321bb7f22b6f7935dce86d258a/langchain_aws-1.0.0a1-py3-none-any.whl", hash = "sha256:24207d05c619ea61dfeab0a0f7086ae388cc3f2f5c03a8ae56b12d1b77d72585", size = 146839, upload-time = "2025-09-18T20:52:35.013Z" },
]
[[package]]
name = "langchain-classic"
version = "1.0.0"
@@ -2825,6 +2880,7 @@ dependencies = [
{ name = "json-repair" },
{ name = "kombu" },
{ name = "langchain" },
{ name = "langchain-aws" },
{ name = "langchain-community" },
{ name = "langchain-mcp-adapters" },
{ name = "langchain-ollama" },
@@ -2949,6 +3005,7 @@ requires-dist = [
{ name = "json-repair", specifier = "==0.53.0" },
{ name = "kombu", specifier = "==5.5.4" },
{ name = "langchain", specifier = ">=1.0.3" },
{ name = "langchain-aws", specifier = ">=1.0.0a1" },
{ name = "langchain-community", specifier = ">=0.3.31" },
{ name = "langchain-mcp-adapters", specifier = ">=0.1.13" },
{ name = "langchain-ollama" },
@@ -3199,6 +3256,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/e5/8925a4208f131b218f9a7e459c0d6fcac8324ae35da269cb437894576366/ruamel_yaml_clib-0.2.15-cp312-cp312-win_amd64.whl", hash = "sha256:2b216904750889133d9222b7b873c199d48ecbb12912aca78970f84a5aa1a4bc", size = 119013, upload-time = "2025-11-16T16:13:32.164Z" },
]
[[package]]
name = "s3transfer"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "botocore" },
]
sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" },
]
[[package]]
name = "scikit-learn"
version = "1.7.2"