Merge branch 'develop' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/app-share-wxy
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
|
||||
@@ -62,7 +62,7 @@ celery_app.conf.update(
|
||||
task_serializer='json',
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=False,
|
||||
@@ -70,43 +70,44 @@ celery_app.conf.update(
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
@@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab(
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@@ -16,6 +16,7 @@ from . import (
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
@@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -844,6 +844,15 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -889,6 +898,8 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -900,7 +911,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
@@ -931,7 +942,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,10 +43,10 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
@@ -69,7 +72,7 @@ async def login_for_access_token(
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -82,7 +85,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -110,14 +113,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -125,18 +129,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -167,7 +171,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -175,14 +179,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -192,5 +197,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
|
||||
def get_languages(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of all supported languages.
|
||||
|
||||
Returns:
|
||||
List of language information including code, name, and status
|
||||
"""
|
||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Get available locales from translation service
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Build language info list
|
||||
languages = []
|
||||
for locale in available_locales:
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
# Get native names
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
languages.append(language_info)
|
||||
|
||||
response = LanguageListResponse(languages=languages)
|
||||
|
||||
api_logger.info(f"Returning {len(languages)} languages")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||
def get_language(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get information about a specific language.
|
||||
|
||||
Args:
|
||||
locale: Language code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Language information
|
||||
"""
|
||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Build language info
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
api_logger.info(f"Returning language info for: {locale}")
|
||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
|
||||
def add_language(
|
||||
request: LanguageCreateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Add a new language (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual language addition
|
||||
requires creating translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
request: Language creation request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language already exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if request.code in available_locales:
|
||||
api_logger.warning(f"Language already exists: {request.code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.language.already_exists", locale=request.code)
|
||||
)
|
||||
|
||||
# Note: Actual language addition requires creating translation files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language addition validated: {request.code}. "
|
||||
"Translation files need to be created manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t(
|
||||
"i18n.language.add_instructions",
|
||||
locale=request.code,
|
||||
dir=settings.I18N_CORE_LOCALES_DIR
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||
def update_language(
|
||||
locale: str,
|
||||
request: LanguageUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update language configuration (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual configuration
|
||||
changes require updating environment variables or config files.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
request: Language update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Note: Actual configuration changes require updating settings
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language update validated: {locale}. "
|
||||
"Configuration changes require environment variable updates."
|
||||
)
|
||||
|
||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
|
||||
def get_all_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for all or specific locale.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
All translations organized by locale and namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
if locale:
|
||||
# Get translations for specific locale
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = {
|
||||
locale: translation_service._cache.get(locale, {})
|
||||
}
|
||||
else:
|
||||
# Get all translations
|
||||
translations = translation_service._cache
|
||||
|
||||
response = TranslationResponse(translations=translations)
|
||||
|
||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||
def get_locale_translations(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
|
||||
Returns:
|
||||
All translations for the locale organized by namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = translation_service._cache.get(locale, {})
|
||||
|
||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||
def get_namespace_translations(
|
||||
locale: str,
|
||||
namespace: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get translations for a specific namespace in a locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||
|
||||
Returns:
|
||||
Translations for the specified namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get namespace translations request: locale={locale}, "
|
||||
f"namespace={namespace}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Get namespace translations
|
||||
locale_translations = translation_service._cache.get(locale, {})
|
||||
namespace_translations = locale_translations.get(namespace, {})
|
||||
|
||||
if not namespace_translations:
|
||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||
)
|
||||
return success(
|
||||
data={
|
||||
"locale": locale,
|
||||
"namespace": namespace,
|
||||
"translations": namespace_translations
|
||||
},
|
||||
msg=t("common.success.retrieved")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||
def update_translation(
|
||||
locale: str,
|
||||
key: str,
|
||||
request: TranslationUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update a single translation (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual translation updates
|
||||
require modifying translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
request: Translation update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update translation request: locale={locale}, key={key}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Validate key format
|
||||
if "." not in key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||
)
|
||||
|
||||
# Note: Actual translation updates require modifying JSON files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Translation update validated: {locale}/{key}. "
|
||||
"Translation files need to be updated manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/translations/missing", response_model=ApiResponse)
|
||||
def get_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of missing translations.
|
||||
|
||||
Compares translations across locales to find missing keys.
|
||||
|
||||
Args:
|
||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||
|
||||
Returns:
|
||||
List of missing translation keys
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Get default locale translations as reference
|
||||
default_translations = translation_service._cache.get(default_locale, {})
|
||||
|
||||
# Collect all keys from default locale
|
||||
def collect_keys(data, prefix=""):
|
||||
keys = []
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, dict):
|
||||
keys.extend(collect_keys(value, full_key))
|
||||
else:
|
||||
keys.append(full_key)
|
||||
return keys
|
||||
|
||||
default_keys = set()
|
||||
for namespace, translations in default_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
default_keys.update(namespace_keys)
|
||||
|
||||
# Find missing keys in target locale(s)
|
||||
missing_by_locale = {}
|
||||
|
||||
target_locales = [locale] if locale else [
|
||||
loc for loc in available_locales if loc != default_locale
|
||||
]
|
||||
|
||||
for target_locale in target_locales:
|
||||
if target_locale not in available_locales:
|
||||
continue
|
||||
|
||||
target_translations = translation_service._cache.get(target_locale, {})
|
||||
target_keys = set()
|
||||
|
||||
for namespace, translations in target_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
target_keys.update(namespace_keys)
|
||||
|
||||
missing_keys = default_keys - target_keys
|
||||
if missing_keys:
|
||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||
|
||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||
|
||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
|
||||
def reload_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Trigger hot reload of translation files (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to reload (defaults to reloading all locales)
|
||||
|
||||
Returns:
|
||||
Reload status and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Reload translations request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||
api_logger.warning("Hot reload is disabled in configuration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=t("i18n.reload.disabled")
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
try:
|
||||
# Reload translations
|
||||
translation_service.reload(locale)
|
||||
|
||||
# Get statistics
|
||||
available_locales = translation_service.get_available_locales()
|
||||
reloaded_locales = [locale] if locale else available_locales
|
||||
|
||||
response = ReloadResponse(
|
||||
success=True,
|
||||
reloaded_locales=reloaded_locales,
|
||||
total_locales=len(available_locales)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||
)
|
||||
|
||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=t("i18n.reload.failed", error=str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
|
||||
def get_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get i18n performance metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Performance metrics including:
|
||||
- Request counts
|
||||
- Missing translations
|
||||
- Timing statistics
|
||||
- Locale usage
|
||||
- Error counts
|
||||
"""
|
||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
metrics = translation_service.get_metrics_summary()
|
||||
|
||||
api_logger.info("Returning i18n metrics")
|
||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||
def get_cache_stats(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get cache statistics (admin only).
|
||||
|
||||
Returns:
|
||||
Cache statistics including:
|
||||
- Hit/miss rates
|
||||
- LRU cache performance
|
||||
- Loaded locales
|
||||
- Memory usage
|
||||
"""
|
||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
cache_stats = translation_service.get_cache_stats()
|
||||
memory_usage = translation_service.get_memory_usage()
|
||||
|
||||
data = {
|
||||
"cache": cache_stats,
|
||||
"memory": memory_usage
|
||||
}
|
||||
|
||||
api_logger.info("Returning cache statistics")
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
|
||||
def get_prometheus_metrics(
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get metrics in Prometheus format (admin only).
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as plain text
|
||||
"""
|
||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
prometheus_output = metrics.export_prometheus()
|
||||
|
||||
from fastapi.responses import PlainTextResponse
|
||||
return PlainTextResponse(content=prometheus_output)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||
def reset_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Reset all metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
metrics.reset()
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_service.cache.reset_stats()
|
||||
|
||||
api_logger.info("Metrics reset completed")
|
||||
return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
|
||||
def get_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get missing translation logs (admin only).
|
||||
|
||||
Returns logged missing translations with context information.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
limit: Maximum number of entries to return (default: 100)
|
||||
|
||||
Returns:
|
||||
Missing translation logs with context
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translation logs request: locale={locale}, "
|
||||
f"limit={limit}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Get missing translations
|
||||
missing_translations = translation_logger.get_missing_translations(locale)
|
||||
|
||||
# Get missing with context
|
||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||
|
||||
# Get statistics
|
||||
statistics = translation_logger.get_statistics()
|
||||
|
||||
data = {
|
||||
"missing_translations": missing_translations,
|
||||
"recent_context": missing_with_context,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"Returning {statistics['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||
def generate_missing_translation_report(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Generate a comprehensive missing translation report (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Comprehensive report with missing translations and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Generate missing translation report request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate report
|
||||
report = translation_logger.generate_report(locale)
|
||||
|
||||
api_logger.info(
|
||||
f"Generated report with {report['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=report, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||
def export_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Export missing translations to JSON file (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Export status and file path
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Export missing translations request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
locale_suffix = f"_{locale}" if locale else "_all"
|
||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||
|
||||
# Export to file
|
||||
translation_logger.export_to_json(output_file)
|
||||
|
||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||
return success(
|
||||
data={"file_path": output_file},
|
||||
msg=t("i18n.logs.export_success", file=output_file)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||
def clear_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Clear missing translation logs (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to clear (clears all if not specified)
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Clear logs
|
||||
translation_logger.clear(locale)
|
||||
|
||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||
return success(msg=t("i18n.logs.clear_success"))
|
||||
@@ -1,3 +1,19 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
||||
)
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
# 跳过没有配置的应用
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
@@ -204,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -42,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -50,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -57,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -74,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -105,7 +111,7 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -113,6 +119,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -125,7 +132,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +141,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -144,7 +152,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -152,6 +160,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -164,7 +173,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -172,6 +181,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -186,16 +196,17 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
@@ -203,8 +214,8 @@ def verify_pwd(
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
@@ -212,6 +223,7 @@ async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
@@ -219,7 +231,7 @@ async def send_email_code(
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
@@ -227,6 +239,7 @@ async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
@@ -239,4 +252,51 @@ async def change_email(
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
|
||||
def get_current_user_language(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户的语言偏好"""
|
||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||
|
||||
language = user_service.get_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||
msg=t("users.language.get_success")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
|
||||
def update_current_user_language(
|
||||
request: user_schema.LanguagePreferenceRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""设置当前用户的语言偏好"""
|
||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||
|
||||
updated_user = user_service.update_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
language=request.language,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||
msg=t("users.language.update_success")
|
||||
)
|
||||
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,8 +96,13 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@@ -98,6 +111,8 @@ def create_workspace(
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -118,8 +133,13 @@ def create_workspace(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -127,6 +147,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -139,14 +161,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -168,6 +203,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -178,7 +214,7 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@@ -187,6 +223,7 @@ def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
@@ -198,7 +235,7 @@ def delete_workspace_member(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -208,6 +245,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -220,7 +259,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -232,6 +276,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -246,18 +292,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -302,6 +367,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -312,7 +378,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -320,6 +386,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -354,14 +423,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
@@ -162,6 +162,44 @@ class Settings:
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
@@ -2,15 +2,37 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
|
||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
"""
|
||||
Problem processing node service class
|
||||
|
||||
Handles problem decomposition and extension operations using LLM services.
|
||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including spit_context field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
|
||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -128,7 +169,7 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content
|
||||
text_parts = []
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
@@ -146,10 +187,19 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
@@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
@@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define semaphore at module level to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -54,6 +75,23 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
@@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
|
||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
@@ -48,7 +60,7 @@ async def make_read_graph():
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
@@ -63,7 +75,7 @@ async def make_read_graph():
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
@@ -72,108 +84,3 @@ async def make_read_graph():
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
print(f"总耗时: {end - start}s")
|
||||
print(100 * 'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,77 +2,104 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
actual_end_user_id, # end_user_id: User ID
|
||||
structured_messages, # message: JSON string format message list
|
||||
str(actual_config_id), # config_id: Configuration ID string
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
"""Time-based memory processing"""
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
history = None
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
"""
|
||||
Input schema for time retrieval tool
|
||||
|
||||
Defines the expected input parameters for time-based retrieval operations.
|
||||
Used for validation and documentation of tool parameters.
|
||||
|
||||
Attributes:
|
||||
context: User input query content for search
|
||||
end_user_id: Group ID for filtering search results, defaults to test user
|
||||
"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,51 +216,60 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
data: Data to be cleaned (can be dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
# Clean each element in list
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -216,57 +279,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -275,38 +344,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -315,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
Format and parse message lists into different output types
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
messages: List of message objects from storage containing message data
|
||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
list: Formatted message list in the specified format
|
||||
- 'string': List of formatted text messages with role prefixes
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -42,10 +41,26 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
自我反思引擎实现
|
||||
Self-Reflection Engine Implementation
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
1. 基于时间的反思 - 根据时间周期触发反思
|
||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
||||
3. 综合反思 - 整合多种反思策略
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
This module implements the self-reflection functionality of the memory system, including:
|
||||
1. Time-based reflection - Triggers reflection based on time cycles
|
||||
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||
4. Reflection result application - Updates memory database
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
# Configure logging
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(
|
||||
@@ -49,35 +49,62 @@ else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
"""Translation response model for language conversion"""
|
||||
data: str
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
"""
|
||||
Reflection range enumeration
|
||||
|
||||
Defines the scope of data to be included in reflection operations.
|
||||
"""
|
||||
PARTIAL = "partial" # Reflect from retrieval results
|
||||
ALL = "all" # Reflect from entire database
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
"""
|
||||
Reflection baseline enumeration
|
||||
|
||||
Defines the strategy or approach used for reflection operations.
|
||||
"""
|
||||
TIME = "TIME" # Time-based reflection
|
||||
FACT = "FACT" # Fact-based reflection
|
||||
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
|
||||
|
||||
|
||||
class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
"""
|
||||
Reflection engine configuration
|
||||
|
||||
Defines all configuration parameters for the reflection engine including
|
||||
operation modes, model settings, and evaluation criteria.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether reflection engine is enabled
|
||||
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||
model_id: LLM model identifier for reflection operations
|
||||
end_user_id: User identifier for scoped operations
|
||||
output_example: Example output format for guidance
|
||||
memory_verify: Enable memory verification checks
|
||||
quality_assessment: Enable quality assessment evaluation
|
||||
violation_handling_strategy: Strategy for handling violations
|
||||
language_type: Language type for output ("zh" or "en")
|
||||
"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
iteration_period: str = "3" # Reflection cycle period
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
model_id: Optional[str] = None # Model ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
output_example: Optional[str] = None # Output example
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
# Evaluation related fields
|
||||
memory_verify: bool = True # Memory verification
|
||||
quality_assessment: bool = True # Quality assessment
|
||||
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
||||
|
||||
|
||||
class ReflectionResult(BaseModel):
|
||||
"""反思结果"""
|
||||
"""
|
||||
Reflection operation result
|
||||
|
||||
Contains comprehensive information about the outcome of a reflection operation
|
||||
including success status, metrics, and execution details.
|
||||
|
||||
Attributes:
|
||||
success: Whether the reflection operation succeeded
|
||||
message: Descriptive message about the operation result
|
||||
conflicts_found: Number of conflicts detected during reflection
|
||||
conflicts_resolved: Number of conflicts successfully resolved
|
||||
memories_updated: Number of memory entries updated in database
|
||||
execution_time: Total time taken for the reflection operation
|
||||
details: Additional details about the operation (optional)
|
||||
"""
|
||||
success: bool
|
||||
message: str
|
||||
conflicts_found: int = 0
|
||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
||||
|
||||
class ReflectionEngine:
|
||||
"""
|
||||
自我反思引擎
|
||||
|
||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
||||
Self-Reflection Engine
|
||||
|
||||
Responsible for executing memory system self-reflection operations including
|
||||
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||
reflection strategies and provides comprehensive result tracking.
|
||||
|
||||
The engine can operate in different modes:
|
||||
- Time-based: Reflects on memories within specific time periods
|
||||
- Fact-based: Detects and resolves factual conflicts in memories
|
||||
- Hybrid: Combines multiple reflection strategies
|
||||
|
||||
Attributes:
|
||||
config: Reflection engine configuration
|
||||
neo4j_connector: Neo4j database connector
|
||||
llm_client: Language model client for analysis
|
||||
Various function handlers for data processing and prompt rendering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
Initialize reflection engine
|
||||
|
||||
Sets up the reflection engine with configuration and optional dependencies.
|
||||
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||
|
||||
Args:
|
||||
config: 反思引擎配置
|
||||
neo4j_connector: Neo4j 连接器(可选)
|
||||
llm_client: LLM 客户端(可选)
|
||||
get_data_func: 获取数据的函数(可选)
|
||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
||||
conflict_schema: 冲突结果 Schema(可选)
|
||||
reflexion_schema: 反思结果 Schema(可选)
|
||||
update_query: 更新查询语句(可选)
|
||||
config: Reflection engine configuration object
|
||||
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||
llm_client: LLM client instance (optional, will be created if not provided)
|
||||
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||
conflict_schema: Schema for conflict result validation (optional)
|
||||
reflexion_schema: Schema for reflection result validation (optional)
|
||||
update_query: Query string for database updates (optional)
|
||||
"""
|
||||
self.config = config
|
||||
self.neo4j_connector = neo4j_connector
|
||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
# Lazy import to avoid circular dependencies
|
||||
self._lazy_init_done = False
|
||||
|
||||
def _lazy_init(self):
|
||||
"""延迟初始化,避免循环导入"""
|
||||
"""
|
||||
Lazy initialization to avoid circular imports
|
||||
|
||||
Initializes dependencies only when needed, preventing circular import issues
|
||||
and optimizing startup performance. Sets up default implementations for
|
||||
any components not provided during construction.
|
||||
"""
|
||||
if self._lazy_init_done:
|
||||
return
|
||||
|
||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# If llm_client is a string (model_id), use it to initialize the client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||
"max_tokens": 600, # Limit maximum token count
|
||||
"top_p": 0.8, # Optimize sampling parameters
|
||||
"stream": False, # Ensure non-streaming output for fastest response
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
||||
if self.get_data_func is None:
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
# Import get_data_statement function
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
||||
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
Execute complete reflection workflow
|
||||
|
||||
Performs the full reflection process including data retrieval, conflict detection,
|
||||
conflict resolution, and memory updates. This is the main entry point for
|
||||
reflection operations.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host identifier for scoping reflection operations
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive result of the reflection operation including
|
||||
success status, conflict metrics, and execution time
|
||||
"""
|
||||
# 延迟初始化
|
||||
# Lazy initialization
|
||||
self._lazy_init()
|
||||
|
||||
if not self.config.enabled:
|
||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
# 1. Get reflection data
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
# 2. Detect conflicts (fact-based reflection)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
# 4. Apply reflection results (update memory database)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
"""
|
||||
Translate Chinese text to English
|
||||
|
||||
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||
Provides consistent translation format for reflection results.
|
||||
|
||||
Args:
|
||||
text: Chinese text to be translated
|
||||
|
||||
Returns:
|
||||
str: Translated English text
|
||||
"""
|
||||
# Translate Chinese to English
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
"""
|
||||
Extract and translate reflection data to English
|
||||
|
||||
Processes reflection data structure and translates all Chinese content to English.
|
||||
Handles nested data structures including memory verifications, quality assessments,
|
||||
and reflection data while preserving the original structure.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing reflection data with Chinese content
|
||||
|
||||
Returns:
|
||||
dict: Translated data structure with English content
|
||||
"""
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
"""
|
||||
Execute reflection workflow with comprehensive data processing
|
||||
|
||||
Performs a complete reflection operation including conflict detection, resolution,
|
||||
and result formatting. Supports both Chinese and English output based on
|
||||
configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Comprehensive reflection results including source data, memory verifications,
|
||||
quality assessments, and reflection data. Results are translated to English
|
||||
if language_type is set to 'en'.
|
||||
"""
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
# Traverse data to extract fields
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
# Remove created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
# Traverse data to extract reflexion fields
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
"""
|
||||
Extract source_data and databasets fields from example.json
|
||||
|
||||
Reads reflection example data from the example.json file and extracts
|
||||
the source data and database statements for testing and demonstration purposes.
|
||||
|
||||
Returns:
|
||||
tuple: (source_data, databasets) extracted from the example file
|
||||
Returns empty lists if file reading fails
|
||||
"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
# Read JSON file
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
# Extract fields under memory_verify
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
|
||||
根据配置的反思范围获取需要反思的记忆数据。
|
||||
Get reflection data from database
|
||||
|
||||
Retrieves memory data for reflection based on the configured reflection range.
|
||||
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping data retrieval
|
||||
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||
Returns empty lists if query fails
|
||||
"""
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
||||
Detect conflicts (fact-based reflection)
|
||||
|
||||
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||
Performs comprehensive conflict detection including memory verification and
|
||||
quality assessment based on configuration settings.
|
||||
|
||||
Args:
|
||||
data: 待检测的记忆数据
|
||||
data: Memory data to be analyzed for conflicts
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 冲突记忆列表
|
||||
List[Any]: List of detected conflicts with detailed analysis
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
# Data preprocessing: if data is too small, return no conflicts directly
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# Use converted data
|
||||
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
# Render conflict detection prompt
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema,
|
||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
# 调用 LLM 进行冲突检测
|
||||
# Call LLM for conflict detection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.conflict_schema
|
||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
||||
logging.error("LLM 冲突检测输出解析失败")
|
||||
return []
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return [response.model_dump()]
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
使用 LLM 对检测到的冲突进行反思和解决。
|
||||
Resolve detected conflicts
|
||||
|
||||
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||
|
||||
Args:
|
||||
conflicts: 冲突列表
|
||||
conflicts: List of conflicts to be resolved
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 解决方案列表
|
||||
List[Any]: List of resolution solutions with reflection results
|
||||
"""
|
||||
if not conflicts:
|
||||
return []
|
||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
# Process each conflict in parallel
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解决单个冲突"""
|
||||
"""Resolve a single conflict"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
# 渲染反思提示词
|
||||
# Render reflection prompt
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema,
|
||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
# 调用 LLM 进行反思
|
||||
# Call LLM for reflection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.reflexion_schema
|
||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return response.model_dump()
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
||||
logging.warning(f"解决单个冲突失败: {e}")
|
||||
return None
|
||||
|
||||
# 并发执行所有冲突解决任务
|
||||
# Execute all conflict resolution tasks concurrently
|
||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 过滤掉失败的结果
|
||||
# Filter out failed results
|
||||
solved = [r for r in results if r is not None]
|
||||
|
||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
|
||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
||||
Apply reflection results (update memory database)
|
||||
|
||||
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||
Processes the solved data and applies changes to the memory storage system.
|
||||
|
||||
Args:
|
||||
solved_data: 解决方案列表
|
||||
solved_data: List of resolved conflict solutions with reflection data
|
||||
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
int: Number of successfully updated memory entries
|
||||
"""
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
|
||||
# 基于时间的反思方法
|
||||
# Time-based reflection methods
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
|
||||
根据时间周期触发反思,检查在指定时间段内的记忆。
|
||||
Time-based reflection
|
||||
|
||||
Triggers reflection based on time cycles, checking memories within
|
||||
specified time periods. Uses the configured iteration period if
|
||||
no specific time period is provided.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
time_period: Time period (e.g., "three hours"), uses config value if not provided
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
period = time_period or self.config.iteration_period
|
||||
logging.info(f"执行基于时间的反思,周期: {period}")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 基于事实的反思方法
|
||||
# Fact-based reflection methods
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
|
||||
检测记忆中的事实冲突并解决。
|
||||
Fact-based reflection
|
||||
|
||||
Detects and resolves factual conflicts within memories. Analyzes
|
||||
memory data for inconsistencies and contradictions that need resolution.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
logging.info("执行基于事实的反思")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 综合反思方法
|
||||
# Comprehensive reflection methods
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
|
||||
整合基于时间和基于事实的反思策略。
|
||||
Comprehensive reflection
|
||||
|
||||
Integrates time-based and fact-based reflection strategies based on
|
||||
the configured baseline. Supports hybrid approaches that combine
|
||||
multiple reflection methodologies.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result combining
|
||||
multiple strategies if using hybrid baseline
|
||||
"""
|
||||
logging.info("执行综合反思")
|
||||
|
||||
# 根据配置的基线选择反思策略
|
||||
# Choose reflection strategy based on configured baseline
|
||||
if self.config.baseline == ReflectionBaseline.TIME:
|
||||
return await self.time_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||
return await self.fact_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
||||
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||
time_result = await self.time_based_reflection(host_id)
|
||||
fact_result = await self.fact_based_reflection(host_id)
|
||||
|
||||
# 合并结果
|
||||
# Merge results
|
||||
return ReflectionResult(
|
||||
success=time_result.success and fact_result.success,
|
||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||
|
||||
@@ -2,9 +2,17 @@ import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
Escape special characters in Lucene queries
|
||||
|
||||
Prevents ParseException when using Neo4j full-text search procedures.
|
||||
Escapes all Lucene reserved special characters and operators.
|
||||
|
||||
Args:
|
||||
query: Original query string
|
||||
|
||||
Returns:
|
||||
str: Escaped query string safe for Lucene search
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
"""
|
||||
Extract clean plain-text query from various input forms
|
||||
|
||||
Handles the following cases:
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Fallbacks to the raw string when parsing fails
|
||||
- Falls back to raw string when parsing fails
|
||||
- Handles dictionary-type input
|
||||
- Best-effort unescape common escape characters
|
||||
|
||||
Args:
|
||||
query_input: Query input in various forms (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
str: Extracted plain-text query string
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
Validate if the date string is in the format YYYY-MM-DD.
|
||||
Validate if date string conforms to YYYY-MM-DD format
|
||||
|
||||
Args:
|
||||
date_str: Date string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if format is correct, False otherwise
|
||||
"""
|
||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||
return bool(re.match(pattern, date_str))
|
||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
|
||||
"""预处理日期字符串,处理特殊格式"""
|
||||
"""
|
||||
预处理日期字符串,处理特殊格式
|
||||
|
||||
处理以下特殊格式:
|
||||
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||
- 混合分隔符,统一为 "-"
|
||||
|
||||
Args:
|
||||
date_str: 原始日期字符串
|
||||
|
||||
Returns:
|
||||
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||
"""
|
||||
|
||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
|
||||
"""备选解析方案"""
|
||||
"""
|
||||
备选日期解析方案
|
||||
|
||||
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||
支持多种常见的日期格式,包括:
|
||||
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||
- YYYYMMDD, YYMMDD
|
||||
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||
|
||||
Args:
|
||||
date_str: 待解析的日期字符串
|
||||
|
||||
Returns:
|
||||
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||
"""
|
||||
|
||||
# 尝试常见的日期格式[citation:4][citation:5]
|
||||
formats_to_try = [
|
||||
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -620,11 +621,12 @@ class BaseNode(ABC):
|
||||
|
||||
@staticmethod
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
@@ -643,7 +645,7 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
@@ -653,7 +655,8 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
result = []
|
||||
for case in self.typed_config.cases:
|
||||
expressions = []
|
||||
for expression in case.expressions:
|
||||
expressions.append({
|
||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||
"right": expression.right
|
||||
if expression.input_type == ValueInputType.CONSTANT
|
||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||
"operator": expression.operator,
|
||||
})
|
||||
result.append({
|
||||
"expressions": expressions,
|
||||
"logical_operator": case.logical_operator,
|
||||
})
|
||||
return {
|
||||
"cases": result
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
|
||||
@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
model_info = ModelInfo(
|
||||
model_name=api_config.model_name,
|
||||
model_type=ModelType(config.type),
|
||||
api_key=api_config.api_key,
|
||||
api_base=api_config.api_base,
|
||||
provider=api_config.provider,
|
||||
is_omni=api_config.is_omni,
|
||||
capability=api_config.capability
|
||||
)
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
model_name=model_info.model_name,
|
||||
provider=model_info.provider,
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
type=model_info.model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
logger.debug(
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
|
||||
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"text": self._render_template(self.typed_config.text, variable_pool),
|
||||
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
|
||||
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
|
||||
"model_id": str(self.typed_config.model_id),
|
||||
}
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
|
||||
61
api/app/i18n/README.md
Normal file
61
api/app/i18n/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Internationalization (i18n) Module
|
||||
|
||||
This module provides internationalization support for the MemoryBear API.
|
||||
|
||||
## Components
|
||||
|
||||
- `service.py` - Translation service and core translation logic
|
||||
- `middleware.py` - Language detection middleware
|
||||
- `dependencies.py` - FastAPI dependency injection functions
|
||||
- `exceptions.py` - Internationalized exception classes
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
message = t("common.validation.required", field="Name")
|
||||
```
|
||||
|
||||
### Enum Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t_enum
|
||||
|
||||
# Translate enum value
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
```
|
||||
|
||||
### In FastAPI Endpoints
|
||||
|
||||
```python
|
||||
from fastapi import Depends
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
workspace = await workspace_service.create(data)
|
||||
return {
|
||||
"success": True,
|
||||
"message": t("workspace.created_successfully"),
|
||||
"data": workspace
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
See `app/core/config.py` for i18n configuration options:
|
||||
|
||||
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
|
||||
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
|
||||
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
|
||||
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)
|
||||
124
api/app/i18n/__init__.py
Normal file
124
api/app/i18n/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Internationalization (i18n) module for MemoryBear Enterprise.
|
||||
|
||||
This module provides complete i18n support for the backend API including:
|
||||
- Translation loading from multiple directories (community + enterprise)
|
||||
- Translation service with caching and fallback
|
||||
- Language detection middleware
|
||||
- Dependency injection for FastAPI
|
||||
- Convenience functions for easy usage
|
||||
|
||||
Usage:
|
||||
from app.i18n import t, t_enum
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
error = t("common.validation.required", field="名称")
|
||||
|
||||
# Enum translation
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
"""
|
||||
|
||||
from app.i18n.dependencies import (
|
||||
get_current_language,
|
||||
get_enum_translator,
|
||||
get_translator,
|
||||
)
|
||||
from app.i18n.exceptions import (
|
||||
BadRequestError,
|
||||
ConflictError,
|
||||
FileNotFoundError,
|
||||
FileTooLargeError,
|
||||
ForbiddenError,
|
||||
I18nException,
|
||||
InternalServerError,
|
||||
InvalidCredentialsError,
|
||||
InvalidFileTypeError,
|
||||
NotFoundError,
|
||||
QuotaExceededError,
|
||||
RateLimitExceededError,
|
||||
ServiceUnavailableError,
|
||||
TenantNotFoundError,
|
||||
TenantSuspendedError,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
UnauthorizedError,
|
||||
UserAlreadyExistsError,
|
||||
UserNotFoundError,
|
||||
ValidationError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspacePermissionDeniedError,
|
||||
get_current_locale,
|
||||
set_current_locale,
|
||||
)
|
||||
from app.i18n.loader import TranslationLoader
|
||||
from app.i18n.logger import (
|
||||
TranslationLogger,
|
||||
get_translation_logger,
|
||||
log_missing_translation,
|
||||
log_translation_error,
|
||||
)
|
||||
from app.i18n.middleware import LanguageMiddleware
|
||||
from app.i18n.serializers import (
|
||||
I18nResponseMixin,
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer,
|
||||
)
|
||||
from app.i18n.service import (
|
||||
TranslationService,
|
||||
get_translation_service,
|
||||
t,
|
||||
t_enum,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TranslationLoader",
|
||||
"LanguageMiddleware",
|
||||
"TranslationService",
|
||||
"get_translation_service",
|
||||
"t",
|
||||
"t_enum",
|
||||
"get_current_language",
|
||||
"get_translator",
|
||||
"get_enum_translator",
|
||||
# Context management
|
||||
"get_current_locale",
|
||||
"set_current_locale",
|
||||
# Logging
|
||||
"TranslationLogger",
|
||||
"get_translation_logger",
|
||||
"log_missing_translation",
|
||||
"log_translation_error",
|
||||
# Serializers
|
||||
"I18nResponseMixin",
|
||||
"WorkspaceSerializer",
|
||||
"WorkspaceMemberSerializer",
|
||||
"WorkspaceInviteSerializer",
|
||||
# Exception classes
|
||||
"I18nException",
|
||||
"BadRequestError",
|
||||
"UnauthorizedError",
|
||||
"ForbiddenError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"ValidationError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspacePermissionDeniedError",
|
||||
"UserNotFoundError",
|
||||
"UserAlreadyExistsError",
|
||||
"TenantNotFoundError",
|
||||
"TenantSuspendedError",
|
||||
"InvalidCredentialsError",
|
||||
"TokenExpiredError",
|
||||
"TokenInvalidError",
|
||||
"FileNotFoundError",
|
||||
"FileTooLargeError",
|
||||
"InvalidFileTypeError",
|
||||
"RateLimitExceededError",
|
||||
"QuotaExceededError",
|
||||
]
|
||||
291
api/app/i18n/cache.py
Normal file
291
api/app/i18n/cache.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Advanced caching system for i18n translations.
|
||||
|
||||
This module provides:
|
||||
- LRU cache for hot translations
|
||||
- Lazy loading mechanism
|
||||
- Memory optimization
|
||||
- Cache statistics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationCache:
|
||||
"""
|
||||
Advanced translation cache with LRU eviction and lazy loading.
|
||||
|
||||
Features:
|
||||
- LRU cache for frequently accessed translations
|
||||
- Lazy loading to reduce startup time
|
||||
- Memory-efficient storage
|
||||
- Cache hit/miss statistics
|
||||
"""
|
||||
|
||||
def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
|
||||
"""
|
||||
Initialize the translation cache.
|
||||
|
||||
Args:
|
||||
max_lru_size: Maximum size of LRU cache for hot translations
|
||||
enable_lazy_load: Enable lazy loading of locales
|
||||
"""
|
||||
self.max_lru_size = max_lru_size
|
||||
self.enable_lazy_load = enable_lazy_load
|
||||
|
||||
# Main cache: {locale: {namespace: {key: value}}}
|
||||
self._main_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# LRU cache for hot translations
|
||||
self._lru_cache: OrderedDict = OrderedDict()
|
||||
|
||||
# Loaded locales tracker
|
||||
self._loaded_locales: set = set()
|
||||
|
||||
# Statistics
|
||||
self._stats = {
|
||||
"hits": 0,
|
||||
"misses": 0,
|
||||
"lru_hits": 0,
|
||||
"lru_misses": 0,
|
||||
"lazy_loads": 0
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"TranslationCache initialized with LRU size: {max_lru_size}, "
|
||||
f"lazy loading: {enable_lazy_load}"
|
||||
)
|
||||
|
||||
def set_locale_data(self, locale: str, data: Dict[str, Any]):
|
||||
"""
|
||||
Set translation data for a locale.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
data: Translation data dictionary
|
||||
"""
|
||||
self._main_cache[locale] = data
|
||||
self._loaded_locales.add(locale)
|
||||
logger.debug(f"Loaded locale '{locale}' into cache")
|
||||
|
||||
def get_translation(
|
||||
self,
|
||||
locale: str,
|
||||
namespace: str,
|
||||
key_path: list
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get translation from cache with LRU optimization.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key_path: List of nested keys
|
||||
|
||||
Returns:
|
||||
Translation string or None if not found
|
||||
"""
|
||||
# Build cache key for LRU
|
||||
cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}"
|
||||
|
||||
# Check LRU cache first (hot translations)
|
||||
if cache_key in self._lru_cache:
|
||||
self._stats["lru_hits"] += 1
|
||||
self._stats["hits"] += 1
|
||||
# Move to end (most recently used)
|
||||
self._lru_cache.move_to_end(cache_key)
|
||||
return self._lru_cache[cache_key]
|
||||
|
||||
self._stats["lru_misses"] += 1
|
||||
|
||||
# Check main cache
|
||||
if locale not in self._main_cache:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
if namespace not in self._main_cache[locale]:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
# Navigate through nested keys
|
||||
current = self._main_cache[locale][namespace]
|
||||
for key in key_path:
|
||||
if isinstance(current, dict) and key in current:
|
||||
current = current[key]
|
||||
else:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
# Return only if it's a string value
|
||||
if not isinstance(current, str):
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
self._stats["hits"] += 1
|
||||
|
||||
# Add to LRU cache
|
||||
self._add_to_lru(cache_key, current)
|
||||
|
||||
return current
|
||||
|
||||
def _add_to_lru(self, key: str, value: str):
|
||||
"""
|
||||
Add translation to LRU cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Translation value
|
||||
"""
|
||||
# Remove oldest if cache is full
|
||||
if len(self._lru_cache) >= self.max_lru_size:
|
||||
self._lru_cache.popitem(last=False)
|
||||
|
||||
self._lru_cache[key] = value
|
||||
|
||||
def is_locale_loaded(self, locale: str) -> bool:
|
||||
"""
|
||||
Check if a locale is loaded.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
|
||||
Returns:
|
||||
True if locale is loaded
|
||||
"""
|
||||
return locale in self._loaded_locales
|
||||
|
||||
def get_loaded_locales(self) -> list:
|
||||
"""
|
||||
Get list of loaded locales.
|
||||
|
||||
Returns:
|
||||
List of locale codes
|
||||
"""
|
||||
return list(self._loaded_locales)
|
||||
|
||||
def clear_lru(self):
|
||||
"""Clear the LRU cache."""
|
||||
self._lru_cache.clear()
|
||||
logger.info("LRU cache cleared")
|
||||
|
||||
def clear_locale(self, locale: str):
|
||||
"""
|
||||
Clear cache for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
"""
|
||||
if locale in self._main_cache:
|
||||
del self._main_cache[locale]
|
||||
self._loaded_locales.discard(locale)
|
||||
|
||||
# Clear related LRU entries
|
||||
keys_to_remove = [k for k in self._lru_cache if k.startswith(f"{locale}:")]
|
||||
for key in keys_to_remove:
|
||||
del self._lru_cache[key]
|
||||
|
||||
logger.info(f"Cleared cache for locale '{locale}'")
|
||||
|
||||
def clear_all(self):
|
||||
"""Clear all caches."""
|
||||
self._main_cache.clear()
|
||||
self._lru_cache.clear()
|
||||
self._loaded_locales.clear()
|
||||
logger.info("All caches cleared")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
total_requests = self._stats["hits"] + self._stats["misses"]
|
||||
hit_rate = (
|
||||
self._stats["hits"] / total_requests * 100
|
||||
if total_requests > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
|
||||
lru_hit_rate = (
|
||||
self._stats["lru_hits"] / lru_total * 100
|
||||
if lru_total > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"hits": self._stats["hits"],
|
||||
"misses": self._stats["misses"],
|
||||
"hit_rate": round(hit_rate, 2),
|
||||
"lru_hits": self._stats["lru_hits"],
|
||||
"lru_misses": self._stats["lru_misses"],
|
||||
"lru_hit_rate": round(lru_hit_rate, 2),
|
||||
"lru_size": len(self._lru_cache),
|
||||
"lru_max_size": self.max_lru_size,
|
||||
"loaded_locales": len(self._loaded_locales),
|
||||
"lazy_loads": self._stats["lazy_loads"]
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset cache statistics."""
|
||||
self._stats = {
|
||||
"hits": 0,
|
||||
"misses": 0,
|
||||
"lru_hits": 0,
|
||||
"lru_misses": 0,
|
||||
"lazy_loads": 0
|
||||
}
|
||||
logger.info("Cache statistics reset")
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Estimate memory usage of the cache.
|
||||
|
||||
Returns:
|
||||
Dictionary with memory usage information
|
||||
"""
|
||||
import sys
|
||||
|
||||
main_cache_size = sys.getsizeof(self._main_cache)
|
||||
lru_cache_size = sys.getsizeof(self._lru_cache)
|
||||
|
||||
# Rough estimate of nested data
|
||||
for locale_data in self._main_cache.values():
|
||||
main_cache_size += sys.getsizeof(locale_data)
|
||||
for namespace_data in locale_data.values():
|
||||
main_cache_size += sys.getsizeof(namespace_data)
|
||||
|
||||
return {
|
||||
"main_cache_bytes": main_cache_size,
|
||||
"lru_cache_bytes": lru_cache_size,
|
||||
"total_bytes": main_cache_size + lru_cache_size,
|
||||
"main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
|
||||
"lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
|
||||
"total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
|
||||
"""
|
||||
LRU cached function for building translation cache keys.
|
||||
|
||||
This reduces string concatenation overhead for frequently accessed keys.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key: Translation key
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
return f"{locale}:{namespace}:{key}"
|
||||
158
api/app/i18n/dependencies.py
Normal file
158
api/app/i18n/dependencies.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
FastAPI dependency injection functions for i18n.
|
||||
|
||||
This module provides dependency injection functions that can be used
|
||||
in FastAPI route handlers to access the current language and translator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_language(request: Request) -> str:
|
||||
"""
|
||||
Get the current language from the request context.
|
||||
|
||||
This dependency extracts the language that was determined by the
|
||||
LanguageMiddleware and stored in request.state.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Language code (e.g., "zh", "en")
|
||||
|
||||
Usage:
|
||||
@router.get("/example")
|
||||
async def example(language: str = Depends(get_current_language)):
|
||||
return {"language": language}
|
||||
"""
|
||||
# Get language from request state (set by LanguageMiddleware)
|
||||
language = getattr(request.state, "language", None)
|
||||
|
||||
if language is None:
|
||||
# Fallback to default language if not set
|
||||
from app.core.config import settings
|
||||
language = settings.I18N_DEFAULT_LANGUAGE
|
||||
logger.warning(
|
||||
"Language not found in request.state, using default: "
|
||||
f"{language}"
|
||||
)
|
||||
|
||||
return language
|
||||
|
||||
|
||||
async def get_translator(request: Request) -> Callable:
|
||||
"""
|
||||
Get a translator function bound to the current request's language.
|
||||
|
||||
This dependency returns a translation function that automatically
|
||||
uses the current request's language, making it easy to translate
|
||||
strings in route handlers.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Translation function with signature: t(key: str, **params) -> str
|
||||
|
||||
Usage:
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
workspace = await workspace_service.create(data)
|
||||
return {
|
||||
"success": True,
|
||||
"message": t("workspace.created_successfully"),
|
||||
"data": workspace
|
||||
}
|
||||
|
||||
# With parameters
|
||||
@router.get("/items")
|
||||
async def get_items(t: Callable = Depends(get_translator)):
|
||||
count = 5
|
||||
return {
|
||||
"message": t("items.found", count=count)
|
||||
}
|
||||
"""
|
||||
# Get current language
|
||||
language = await get_current_language(request)
|
||||
|
||||
# Get translation service
|
||||
service = get_translation_service()
|
||||
|
||||
# Return a bound translation function
|
||||
def translate(key: str, **params) -> str:
|
||||
"""
|
||||
Translate a key using the current request's language.
|
||||
|
||||
Args:
|
||||
key: Translation key (e.g., "common.success.created")
|
||||
**params: Parameters for parameterized messages
|
||||
|
||||
Returns:
|
||||
Translated string
|
||||
"""
|
||||
return service.translate(key, language, **params)
|
||||
|
||||
return translate
|
||||
|
||||
|
||||
async def get_enum_translator(request: Request) -> Callable:
|
||||
"""
|
||||
Get an enum translator function bound to the current request's language.
|
||||
|
||||
This dependency returns a function for translating enum values
|
||||
that automatically uses the current request's language.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Enum translation function with signature:
|
||||
t_enum(enum_type: str, value: str) -> str
|
||||
|
||||
Usage:
|
||||
@router.get("/workspace/{id}")
|
||||
async def get_workspace(
|
||||
id: str,
|
||||
t_enum: Callable = Depends(get_enum_translator)
|
||||
):
|
||||
workspace = await workspace_service.get(id)
|
||||
return {
|
||||
"id": workspace.id,
|
||||
"role": workspace.role,
|
||||
"role_display": t_enum("workspace_role", workspace.role),
|
||||
"status": workspace.status,
|
||||
"status_display": t_enum("workspace_status", workspace.status)
|
||||
}
|
||||
"""
|
||||
# Get current language
|
||||
language = await get_current_language(request)
|
||||
|
||||
# Get translation service
|
||||
service = get_translation_service()
|
||||
|
||||
# Return a bound enum translation function
|
||||
def translate_enum(enum_type: str, value: str) -> str:
|
||||
"""
|
||||
Translate an enum value using the current request's language.
|
||||
|
||||
Args:
|
||||
enum_type: Enum type name (e.g., "workspace_role")
|
||||
value: Enum value (e.g., "manager")
|
||||
|
||||
Returns:
|
||||
Translated enum display name
|
||||
"""
|
||||
return service.translate_enum(enum_type, value, language)
|
||||
|
||||
return translate_enum
|
||||
495
api/app/i18n/exceptions.py
Normal file
495
api/app/i18n/exceptions.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Internationalized exception classes for i18n system.
|
||||
|
||||
This module provides exception classes that automatically translate
|
||||
error messages based on the current request's language.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to store current locale
|
||||
_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None)
|
||||
|
||||
|
||||
def set_current_locale(locale: str) -> None:
|
||||
"""
|
||||
Set the current locale in the context variable.
|
||||
|
||||
This should be called by the LanguageMiddleware.
|
||||
|
||||
Args:
|
||||
locale: Locale code (e.g., "zh", "en")
|
||||
"""
|
||||
_current_locale.set(locale)
|
||||
|
||||
|
||||
def get_current_locale() -> Optional[str]:
|
||||
"""
|
||||
Get the current locale from the context variable.
|
||||
|
||||
Returns:
|
||||
Locale code or None if not set
|
||||
"""
|
||||
return _current_locale.get()
|
||||
|
||||
|
||||
class I18nException(HTTPException):
|
||||
"""
|
||||
Base exception class with automatic i18n support.
|
||||
|
||||
This exception automatically translates error messages based on:
|
||||
1. The current request's language (from request.state.language)
|
||||
2. The fallback language if request language is not available
|
||||
3. The error key itself if no translation is found
|
||||
|
||||
Features:
|
||||
- Automatic error message translation
|
||||
- Parameterized error messages support
|
||||
- Consistent error response format
|
||||
- Language-aware error handling
|
||||
|
||||
Usage:
|
||||
# Simple error
|
||||
raise I18nException(
|
||||
error_key="errors.workspace.not_found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
# Error with parameters
|
||||
raise I18nException(
|
||||
error_key="errors.validation.missing_field",
|
||||
status_code=400,
|
||||
field="name"
|
||||
)
|
||||
|
||||
# Custom error code
|
||||
raise I18nException(
|
||||
error_key="errors.workspace.not_found",
|
||||
error_code="WORKSPACE_NOT_FOUND",
|
||||
status_code=404,
|
||||
workspace_id="123"
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str,
|
||||
status_code: int = 400,
|
||||
error_code: Optional[str] = None,
|
||||
locale: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
**params
|
||||
):
|
||||
"""
|
||||
Initialize the i18n exception.
|
||||
|
||||
Args:
|
||||
error_key: Translation key for the error message
|
||||
(e.g., "errors.workspace.not_found")
|
||||
status_code: HTTP status code (default: 400)
|
||||
error_code: Custom error code for API clients
|
||||
(default: derived from error_key)
|
||||
locale: Target locale for translation (optional)
|
||||
If not provided, uses current request's language
|
||||
headers: Additional HTTP headers
|
||||
**params: Parameters for parameterized error messages
|
||||
"""
|
||||
self.error_key = error_key
|
||||
self.error_code = error_code or self._generate_error_code(error_key)
|
||||
self.params = params
|
||||
|
||||
# Get locale from request context if not provided
|
||||
if locale is None:
|
||||
locale = self._get_current_locale()
|
||||
|
||||
# Translate error message
|
||||
translation_service = get_translation_service()
|
||||
message = translation_service.translate(
|
||||
error_key,
|
||||
locale,
|
||||
**params
|
||||
)
|
||||
|
||||
# Build error detail
|
||||
detail = {
|
||||
"error_code": self.error_code,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
# Add parameters to detail if provided
|
||||
if params:
|
||||
detail["params"] = params
|
||||
|
||||
# Initialize HTTPException
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
detail=detail,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"I18nException raised: {self.error_code} "
|
||||
f"(key: {error_key}, locale: {locale})"
|
||||
)
|
||||
|
||||
def _get_current_locale(self) -> str:
|
||||
"""
|
||||
Get the current locale from request context.
|
||||
|
||||
Returns:
|
||||
Locale code (e.g., "zh", "en")
|
||||
"""
|
||||
try:
|
||||
# Try to get locale from context variable
|
||||
locale = _current_locale.get()
|
||||
if locale:
|
||||
return locale
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get locale from context: {e}")
|
||||
|
||||
# Fallback to default locale
|
||||
from app.core.config import settings
|
||||
return settings.I18N_DEFAULT_LANGUAGE
|
||||
|
||||
def _generate_error_code(self, error_key: str) -> str:
|
||||
"""
|
||||
Generate error code from error key.
|
||||
|
||||
Converts "errors.workspace.not_found" to "WORKSPACE_NOT_FOUND"
|
||||
|
||||
Args:
|
||||
error_key: Translation key
|
||||
|
||||
Returns:
|
||||
Error code in UPPER_SNAKE_CASE
|
||||
"""
|
||||
# Remove "errors." prefix if present
|
||||
if error_key.startswith("errors."):
|
||||
error_key = error_key[7:]
|
||||
|
||||
# Convert to UPPER_SNAKE_CASE
|
||||
parts = error_key.split(".")
|
||||
return "_".join(parts).upper()
|
||||
|
||||
|
||||
# Specific exception classes for common errors
|
||||
|
||||
class BadRequestError(I18nException):
|
||||
"""Bad request error (400)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.bad_request",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=400,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class UnauthorizedError(I18nException):
|
||||
"""Unauthorized error (401)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.auth.unauthorized",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=401,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class ForbiddenError(I18nException):
|
||||
"""Forbidden error (403)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.auth.forbidden",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=403,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class NotFoundError(I18nException):
|
||||
"""Not found error (404)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.not_found",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=404,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class ConflictError(I18nException):
|
||||
"""Conflict error (409)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.conflict",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=409,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class ValidationError(I18nException):
|
||||
"""Validation error (422)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.validation_failed",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=422,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class InternalServerError(I18nException):
|
||||
"""Internal server error (500)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.internal_error",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=500,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class ServiceUnavailableError(I18nException):
|
||||
"""Service unavailable error (503)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_key: str = "errors.common.service_unavailable",
|
||||
error_code: Optional[str] = None,
|
||||
**params
|
||||
):
|
||||
super().__init__(
|
||||
error_key=error_key,
|
||||
status_code=503,
|
||||
error_code=error_code,
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
# Domain-specific exception classes
|
||||
|
||||
class WorkspaceNotFoundError(NotFoundError):
|
||||
"""Workspace not found error."""
|
||||
|
||||
def __init__(self, workspace_id: Optional[str] = None, **params):
|
||||
if workspace_id:
|
||||
params["workspace_id"] = workspace_id
|
||||
super().__init__(
|
||||
error_key="errors.workspace.not_found",
|
||||
error_code="WORKSPACE_NOT_FOUND",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class WorkspacePermissionDeniedError(ForbiddenError):
|
||||
"""Workspace permission denied error."""
|
||||
|
||||
def __init__(self, workspace_id: Optional[str] = None, **params):
|
||||
if workspace_id:
|
||||
params["workspace_id"] = workspace_id
|
||||
super().__init__(
|
||||
error_key="errors.workspace.permission_denied",
|
||||
error_code="WORKSPACE_PERMISSION_DENIED",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class UserNotFoundError(NotFoundError):
|
||||
"""User not found error."""
|
||||
|
||||
def __init__(self, user_id: Optional[str] = None, **params):
|
||||
if user_id:
|
||||
params["user_id"] = user_id
|
||||
super().__init__(
|
||||
error_key="errors.user.not_found",
|
||||
error_code="USER_NOT_FOUND",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ConflictError):
|
||||
"""User already exists error."""
|
||||
|
||||
def __init__(self, identifier: Optional[str] = None, **params):
|
||||
if identifier:
|
||||
params["identifier"] = identifier
|
||||
super().__init__(
|
||||
error_key="errors.user.already_exists",
|
||||
error_code="USER_ALREADY_EXISTS",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class TenantNotFoundError(NotFoundError):
|
||||
"""Tenant not found error."""
|
||||
|
||||
def __init__(self, tenant_id: Optional[str] = None, **params):
|
||||
if tenant_id:
|
||||
params["tenant_id"] = tenant_id
|
||||
super().__init__(
|
||||
error_key="errors.tenant.not_found",
|
||||
error_code="TENANT_NOT_FOUND",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class TenantSuspendedError(ForbiddenError):
|
||||
"""Tenant suspended error."""
|
||||
|
||||
def __init__(self, tenant_id: Optional[str] = None, **params):
|
||||
if tenant_id:
|
||||
params["tenant_id"] = tenant_id
|
||||
super().__init__(
|
||||
error_key="errors.tenant.suspended",
|
||||
error_code="TENANT_SUSPENDED",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class InvalidCredentialsError(UnauthorizedError):
|
||||
"""Invalid credentials error."""
|
||||
|
||||
def __init__(self, **params):
|
||||
super().__init__(
|
||||
error_key="errors.auth.invalid_credentials",
|
||||
error_code="INVALID_CREDENTIALS",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class TokenExpiredError(UnauthorizedError):
|
||||
"""Token expired error."""
|
||||
|
||||
def __init__(self, **params):
|
||||
super().__init__(
|
||||
error_key="errors.auth.token_expired",
|
||||
error_code="TOKEN_EXPIRED",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class TokenInvalidError(UnauthorizedError):
|
||||
"""Token invalid error."""
|
||||
|
||||
def __init__(self, **params):
|
||||
super().__init__(
|
||||
error_key="errors.auth.token_invalid",
|
||||
error_code="TOKEN_INVALID",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class FileNotFoundError(NotFoundError):
|
||||
"""File not found error."""
|
||||
|
||||
def __init__(self, file_id: Optional[str] = None, **params):
|
||||
if file_id:
|
||||
params["file_id"] = file_id
|
||||
super().__init__(
|
||||
error_key="errors.file.not_found",
|
||||
error_code="FILE_NOT_FOUND",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class FileTooLargeError(BadRequestError):
|
||||
"""File too large error."""
|
||||
|
||||
def __init__(self, max_size: Optional[str] = None, **params):
|
||||
if max_size:
|
||||
params["max_size"] = max_size
|
||||
super().__init__(
|
||||
error_key="errors.file.too_large",
|
||||
error_code="FILE_TOO_LARGE",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class InvalidFileTypeError(BadRequestError):
|
||||
"""Invalid file type error."""
|
||||
|
||||
def __init__(self, file_type: Optional[str] = None, **params):
|
||||
if file_type:
|
||||
params["file_type"] = file_type
|
||||
super().__init__(
|
||||
error_key="errors.file.invalid_type",
|
||||
error_code="INVALID_FILE_TYPE",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class RateLimitExceededError(I18nException):
|
||||
"""Rate limit exceeded error (429)."""
|
||||
|
||||
def __init__(self, **params):
|
||||
super().__init__(
|
||||
error_key="errors.api.rate_limit_exceeded",
|
||||
status_code=429,
|
||||
error_code="RATE_LIMIT_EXCEEDED",
|
||||
**params
|
||||
)
|
||||
|
||||
|
||||
class QuotaExceededError(ForbiddenError):
|
||||
"""Quota exceeded error."""
|
||||
|
||||
def __init__(self, resource: Optional[str] = None, **params):
|
||||
if resource:
|
||||
params["resource"] = resource
|
||||
super().__init__(
|
||||
error_key="errors.api.quota_exceeded",
|
||||
error_code="QUOTA_EXCEEDED",
|
||||
**params
|
||||
)
|
||||
199
api/app/i18n/loader.py
Normal file
199
api/app/i18n/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Translation file loader for i18n system.
|
||||
|
||||
This module handles loading translation files from multiple directories
|
||||
(community edition + enterprise edition) and provides hot reload support.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationLoader:
|
||||
"""
|
||||
Translation file loader that supports:
|
||||
- Loading from multiple directories (community + enterprise)
|
||||
- Hot reload of translation files
|
||||
- Automatic locale detection
|
||||
"""
|
||||
|
||||
def __init__(self, locales_dirs: Optional[List[str]] = None):
|
||||
"""
|
||||
Initialize the translation loader.
|
||||
|
||||
Args:
|
||||
locales_dirs: List of directories containing translation files.
|
||||
If None, will auto-detect from settings.
|
||||
"""
|
||||
if locales_dirs is None:
|
||||
locales_dirs = self._detect_locales_dirs()
|
||||
|
||||
self.locales_dirs = [Path(d) for d in locales_dirs]
|
||||
logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}")
|
||||
|
||||
def _detect_locales_dirs(self) -> List[str]:
|
||||
"""
|
||||
Auto-detect translation directories from settings.
|
||||
|
||||
Returns:
|
||||
List of translation directory paths
|
||||
"""
|
||||
from app.core.config import settings
|
||||
|
||||
dirs = []
|
||||
|
||||
# 1. Core locales directory (community edition, required)
|
||||
core_dir = Path(settings.I18N_CORE_LOCALES_DIR)
|
||||
if core_dir.exists():
|
||||
dirs.append(str(core_dir))
|
||||
logger.debug(f"Found core locales directory: {core_dir}")
|
||||
else:
|
||||
logger.warning(f"Core locales directory not found: {core_dir}")
|
||||
|
||||
# 2. Premium locales directory (enterprise edition, optional)
|
||||
if settings.I18N_PREMIUM_LOCALES_DIR:
|
||||
premium_dir = Path(settings.I18N_PREMIUM_LOCALES_DIR)
|
||||
if premium_dir.exists():
|
||||
dirs.append(str(premium_dir))
|
||||
logger.debug(f"Found premium locales directory: {premium_dir}")
|
||||
else:
|
||||
# Auto-detect premium directory
|
||||
premium_dir = Path("premium/locales")
|
||||
if premium_dir.exists():
|
||||
dirs.append(str(premium_dir))
|
||||
logger.debug(f"Auto-detected premium locales directory: {premium_dir}")
|
||||
|
||||
if not dirs:
|
||||
logger.error("No translation directories found!")
|
||||
|
||||
return dirs
|
||||
|
||||
def get_available_locales(self) -> List[str]:
|
||||
"""
|
||||
Get list of all available locales across all directories.
|
||||
|
||||
Returns:
|
||||
List of locale codes (e.g., ['zh', 'en'])
|
||||
"""
|
||||
locales = set()
|
||||
|
||||
for locales_dir in self.locales_dirs:
|
||||
if not locales_dir.exists():
|
||||
continue
|
||||
|
||||
for locale_dir in locales_dir.iterdir():
|
||||
if locale_dir.is_dir() and not locale_dir.name.startswith('.'):
|
||||
locales.add(locale_dir.name)
|
||||
|
||||
return sorted(list(locales))
|
||||
|
||||
def load_locale(self, locale: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Load all translation files for a specific locale from all directories.
|
||||
|
||||
Translation files are merged with priority:
|
||||
- Later directories override earlier directories
|
||||
- Enterprise translations override community translations
|
||||
|
||||
Args:
|
||||
locale: Locale code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Dictionary of translations organized by namespace
|
||||
Format: {namespace: {key: value, ...}, ...}
|
||||
"""
|
||||
translations = {}
|
||||
|
||||
# Load from each directory in order (later directories override earlier)
|
||||
for locales_dir in self.locales_dirs:
|
||||
locale_dir = locales_dir / locale
|
||||
if not locale_dir.exists():
|
||||
logger.debug(f"Locale directory not found: {locale_dir}")
|
||||
continue
|
||||
|
||||
# Load all JSON files in this locale directory
|
||||
for json_file in locale_dir.glob("*.json"):
|
||||
namespace = json_file.stem
|
||||
|
||||
try:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
new_translations = json.load(f)
|
||||
|
||||
# Merge translations (deep merge)
|
||||
if namespace in translations:
|
||||
translations[namespace] = self._deep_merge(
|
||||
translations[namespace],
|
||||
new_translations
|
||||
)
|
||||
logger.debug(
|
||||
f"Merged translations: {locale}/{namespace} from {json_file}"
|
||||
)
|
||||
else:
|
||||
translations[namespace] = new_translations
|
||||
logger.debug(
|
||||
f"Loaded translations: {locale}/{namespace} from {json_file}"
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Failed to parse JSON file {json_file}: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load translation file {json_file}: {e}"
|
||||
)
|
||||
|
||||
if not translations:
|
||||
logger.warning(f"No translations found for locale: {locale}")
|
||||
|
||||
return translations
|
||||
|
||||
def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Reload translation files.
|
||||
|
||||
Args:
|
||||
locale: Specific locale to reload. If None, reloads all locales.
|
||||
|
||||
Returns:
|
||||
Dictionary of reloaded translations
|
||||
Format: {locale: {namespace: {key: value}}}
|
||||
"""
|
||||
if locale:
|
||||
logger.info(f"Reloading translations for locale: {locale}")
|
||||
return {locale: self.load_locale(locale)}
|
||||
else:
|
||||
logger.info("Reloading all translations")
|
||||
all_translations = {}
|
||||
for loc in self.get_available_locales():
|
||||
all_translations[loc] = self.load_locale(loc)
|
||||
return all_translations
|
||||
|
||||
def _deep_merge(self, base: Dict, override: Dict) -> Dict:
|
||||
"""
|
||||
Deep merge two dictionaries.
|
||||
|
||||
Args:
|
||||
base: Base dictionary
|
||||
override: Dictionary with values to override
|
||||
|
||||
Returns:
|
||||
Merged dictionary
|
||||
"""
|
||||
result = base.copy()
|
||||
|
||||
for key, value in override.items():
|
||||
if (
|
||||
key in result
|
||||
and isinstance(result[key], dict)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
result[key] = self._deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
382
api/app/i18n/logger.py
Normal file
382
api/app/i18n/logger.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
Translation logging for i18n system.
|
||||
|
||||
This module provides:
|
||||
- TranslationLogger for recording missing translations
|
||||
- Missing translation report generation
|
||||
- Integration with existing logging system
|
||||
- Structured logging for translation events
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TranslationLogger:
|
||||
"""
|
||||
Logger for translation events and missing translations.
|
||||
|
||||
Features:
|
||||
- Records missing translations with context
|
||||
- Generates missing translation reports
|
||||
- Integrates with existing logging system
|
||||
- Provides structured logging for analysis
|
||||
"""
|
||||
|
||||
def __init__(self, log_file: Optional[str] = None):
|
||||
"""
|
||||
Initialize translation logger.
|
||||
|
||||
Args:
|
||||
log_file: Optional custom log file path for missing translations
|
||||
"""
|
||||
self.log_file = log_file or "logs/i18n/missing_translations.log"
|
||||
self._missing_translations: Dict[str, Set[str]] = defaultdict(set)
|
||||
self._missing_with_context: List[Dict] = []
|
||||
self._max_context_entries = 10000 # Keep last 10k entries
|
||||
|
||||
# Ensure log directory exists
|
||||
log_path = Path(self.log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create dedicated file handler for missing translations
|
||||
self._file_handler = logging.FileHandler(
|
||||
self.log_file,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self._file_handler.setLevel(logging.WARNING)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(
|
||||
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
self._file_handler.setFormatter(formatter)
|
||||
|
||||
# Create dedicated logger for missing translations
|
||||
self._logger = logging.getLogger("i18n.missing_translations")
|
||||
self._logger.setLevel(logging.WARNING)
|
||||
self._logger.addHandler(self._file_handler)
|
||||
self._logger.propagate = False # Don't propagate to root logger
|
||||
|
||||
logger.info(f"TranslationLogger initialized with log file: {self.log_file}")
|
||||
|
||||
def log_missing_translation(
|
||||
self,
|
||||
key: str,
|
||||
locale: str,
|
||||
context: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Log a missing translation.
|
||||
|
||||
Args:
|
||||
key: Translation key that was not found
|
||||
locale: Locale code
|
||||
context: Optional context information (e.g., request path, user info)
|
||||
"""
|
||||
# Add to missing set
|
||||
self._missing_translations[locale].add(key)
|
||||
|
||||
# Create context entry
|
||||
entry = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"key": key,
|
||||
"locale": locale,
|
||||
"context": context or {}
|
||||
}
|
||||
|
||||
# Keep only recent entries to avoid memory bloat
|
||||
if len(self._missing_with_context) >= self._max_context_entries:
|
||||
self._missing_with_context.pop(0)
|
||||
|
||||
self._missing_with_context.append(entry)
|
||||
|
||||
# Log to file
|
||||
context_str = f" (context: {context})" if context else ""
|
||||
self._logger.warning(
|
||||
f"Missing translation: key='{key}', locale='{locale}'{context_str}"
|
||||
)
|
||||
|
||||
def log_translation_error(
|
||||
self,
|
||||
error_type: str,
|
||||
message: str,
|
||||
key: Optional[str] = None,
|
||||
locale: Optional[str] = None,
|
||||
context: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Log a translation error.
|
||||
|
||||
Args:
|
||||
error_type: Type of error (e.g., "format_error", "parameter_missing")
|
||||
message: Error message
|
||||
key: Translation key (optional)
|
||||
locale: Locale code (optional)
|
||||
context: Optional context information
|
||||
"""
|
||||
error_data = {
|
||||
"error_type": error_type,
|
||||
"message": message,
|
||||
"key": key,
|
||||
"locale": locale,
|
||||
"context": context or {},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self._logger.error(
|
||||
f"Translation error: {error_type} - {message} "
|
||||
f"(key: {key}, locale: {locale})"
|
||||
)
|
||||
|
||||
def log_translation_success(
|
||||
self,
|
||||
key: str,
|
||||
locale: str,
|
||||
duration_ms: Optional[float] = None
|
||||
):
|
||||
"""
|
||||
Log a successful translation (debug level).
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Locale code
|
||||
duration_ms: Optional duration in milliseconds
|
||||
"""
|
||||
duration_str = f" ({duration_ms:.3f}ms)" if duration_ms else ""
|
||||
logger.debug(
|
||||
f"Translation success: key='{key}', locale='{locale}'{duration_str}"
|
||||
)
|
||||
|
||||
def get_missing_translations(
|
||||
self,
|
||||
locale: Optional[str] = None
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get missing translations.
|
||||
|
||||
Args:
|
||||
locale: Specific locale (optional, returns all if None)
|
||||
|
||||
Returns:
|
||||
Dictionary of missing translations by locale
|
||||
"""
|
||||
if locale:
|
||||
return {locale: sorted(list(self._missing_translations.get(locale, set())))}
|
||||
|
||||
return {
|
||||
loc: sorted(list(keys))
|
||||
for loc, keys in self._missing_translations.items()
|
||||
}
|
||||
|
||||
def get_missing_with_context(
|
||||
self,
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Get missing translations with context.
|
||||
|
||||
Args:
|
||||
locale: Filter by locale (optional)
|
||||
limit: Maximum number of entries to return (optional)
|
||||
|
||||
Returns:
|
||||
List of missing translation entries with context
|
||||
"""
|
||||
entries = self._missing_with_context
|
||||
|
||||
# Filter by locale if specified
|
||||
if locale:
|
||||
entries = [e for e in entries if e["locale"] == locale]
|
||||
|
||||
# Apply limit if specified
|
||||
if limit:
|
||||
entries = entries[-limit:]
|
||||
|
||||
return entries
|
||||
|
||||
def generate_report(
|
||||
self,
|
||||
locale: Optional[str] = None,
|
||||
output_file: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Generate a missing translation report.
|
||||
|
||||
Args:
|
||||
locale: Specific locale (optional, generates for all if None)
|
||||
output_file: Optional file path to save report as JSON
|
||||
|
||||
Returns:
|
||||
Report dictionary
|
||||
"""
|
||||
missing = self.get_missing_translations(locale)
|
||||
|
||||
report = {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"total_missing": sum(len(keys) for keys in missing.values()),
|
||||
"missing_by_locale": {
|
||||
loc: {
|
||||
"count": len(keys),
|
||||
"keys": keys
|
||||
}
|
||||
for loc, keys in missing.items()
|
||||
},
|
||||
"recent_context": self.get_missing_with_context(locale, limit=100)
|
||||
}
|
||||
|
||||
# Save to file if specified
|
||||
if output_file:
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Missing translation report saved to: {output_file}")
|
||||
|
||||
return report
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get statistics about missing translations.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
total_missing = sum(len(keys) for keys in self._missing_translations.values())
|
||||
|
||||
# Count by namespace
|
||||
namespace_counts = defaultdict(int)
|
||||
for locale, keys in self._missing_translations.items():
|
||||
for key in keys:
|
||||
namespace = key.split('.')[0] if '.' in key else 'unknown'
|
||||
namespace_counts[namespace] += 1
|
||||
|
||||
return {
|
||||
"total_missing": total_missing,
|
||||
"locales_affected": len(self._missing_translations),
|
||||
"missing_by_locale": {
|
||||
loc: len(keys)
|
||||
for loc, keys in self._missing_translations.items()
|
||||
},
|
||||
"missing_by_namespace": dict(namespace_counts),
|
||||
"total_context_entries": len(self._missing_with_context)
|
||||
}
|
||||
|
||||
def clear(self, locale: Optional[str] = None):
|
||||
"""
|
||||
Clear missing translation records.
|
||||
|
||||
Args:
|
||||
locale: Specific locale to clear (optional, clears all if None)
|
||||
"""
|
||||
if locale:
|
||||
self._missing_translations.pop(locale, None)
|
||||
self._missing_with_context = [
|
||||
e for e in self._missing_with_context
|
||||
if e["locale"] != locale
|
||||
]
|
||||
logger.info(f"Cleared missing translations for locale: {locale}")
|
||||
else:
|
||||
self._missing_translations.clear()
|
||||
self._missing_with_context.clear()
|
||||
logger.info("Cleared all missing translations")
|
||||
|
||||
def export_to_json(self, output_file: str):
|
||||
"""
|
||||
Export all missing translations to JSON file.
|
||||
|
||||
Args:
|
||||
output_file: Output file path
|
||||
"""
|
||||
data = {
|
||||
"exported_at": datetime.now().isoformat(),
|
||||
"missing_translations": self.get_missing_translations(),
|
||||
"statistics": self.get_statistics(),
|
||||
"recent_context": self.get_missing_with_context(limit=1000)
|
||||
}
|
||||
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"Missing translations exported to: {output_file}")
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup file handler on deletion."""
|
||||
try:
|
||||
if hasattr(self, '_file_handler'):
|
||||
self._file_handler.close()
|
||||
self._logger.removeHandler(self._file_handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Global translation logger instance
|
||||
_translation_logger: Optional[TranslationLogger] = None
|
||||
|
||||
|
||||
def get_translation_logger() -> TranslationLogger:
|
||||
"""
|
||||
Get the global translation logger instance.
|
||||
|
||||
Returns:
|
||||
TranslationLogger singleton
|
||||
"""
|
||||
global _translation_logger
|
||||
if _translation_logger is None:
|
||||
_translation_logger = TranslationLogger()
|
||||
return _translation_logger
|
||||
|
||||
|
||||
def log_missing_translation(
|
||||
key: str,
|
||||
locale: str,
|
||||
context: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Log a missing translation (convenience function).
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Locale code
|
||||
context: Optional context information
|
||||
"""
|
||||
translation_logger = get_translation_logger()
|
||||
translation_logger.log_missing_translation(key, locale, context)
|
||||
|
||||
|
||||
def log_translation_error(
|
||||
error_type: str,
|
||||
message: str,
|
||||
key: Optional[str] = None,
|
||||
locale: Optional[str] = None,
|
||||
context: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Log a translation error (convenience function).
|
||||
|
||||
Args:
|
||||
error_type: Type of error
|
||||
message: Error message
|
||||
key: Translation key (optional)
|
||||
locale: Locale code (optional)
|
||||
context: Optional context information
|
||||
"""
|
||||
translation_logger = get_translation_logger()
|
||||
translation_logger.log_translation_error(
|
||||
error_type, message, key, locale, context
|
||||
)
|
||||
337
api/app/i18n/metrics.py
Normal file
337
api/app/i18n/metrics.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Performance monitoring and metrics for i18n system.
|
||||
|
||||
This module provides:
|
||||
- Translation request counters
|
||||
- Translation timing metrics
|
||||
- Missing translation tracking
|
||||
- Performance monitoring decorators
|
||||
- Prometheus-compatible metrics
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationMetrics:
|
||||
"""
|
||||
Metrics collector for translation operations.
|
||||
|
||||
Tracks:
|
||||
- Translation request counts
|
||||
- Translation timing (latency)
|
||||
- Missing translations
|
||||
- Cache performance
|
||||
- Locale usage
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize metrics collector."""
|
||||
# Request counters by locale
|
||||
self._request_counts: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# Missing translation tracker
|
||||
self._missing_translations: Dict[str, set] = defaultdict(set)
|
||||
|
||||
# Timing metrics (in milliseconds)
|
||||
self._timing_data: list = []
|
||||
self._max_timing_samples = 10000 # Keep last 10k samples
|
||||
|
||||
# Locale usage
|
||||
self._locale_usage: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# Namespace usage
|
||||
self._namespace_usage: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# Error counts
|
||||
self._error_counts: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# Start time
|
||||
self._start_time = datetime.now()
|
||||
|
||||
logger.info("TranslationMetrics initialized")
|
||||
|
||||
def record_request(self, locale: str, namespace: str = None):
|
||||
"""
|
||||
Record a translation request.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace (optional)
|
||||
"""
|
||||
self._request_counts[locale] += 1
|
||||
self._locale_usage[locale] += 1
|
||||
|
||||
if namespace:
|
||||
self._namespace_usage[namespace] += 1
|
||||
|
||||
def record_missing(self, key: str, locale: str):
|
||||
"""
|
||||
Record a missing translation.
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Locale code
|
||||
"""
|
||||
self._missing_translations[locale].add(key)
|
||||
logger.debug(f"Missing translation recorded: {key} (locale: {locale})")
|
||||
|
||||
def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"):
|
||||
"""
|
||||
Record translation operation timing.
|
||||
|
||||
Args:
|
||||
duration_ms: Duration in milliseconds
|
||||
locale: Locale code
|
||||
operation: Operation type
|
||||
"""
|
||||
# Keep only recent samples to avoid memory bloat
|
||||
if len(self._timing_data) >= self._max_timing_samples:
|
||||
self._timing_data.pop(0)
|
||||
|
||||
self._timing_data.append({
|
||||
"duration_ms": duration_ms,
|
||||
"locale": locale,
|
||||
"operation": operation,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
|
||||
def record_error(self, error_type: str):
|
||||
"""
|
||||
Record an error.
|
||||
|
||||
Args:
|
||||
error_type: Type of error
|
||||
"""
|
||||
self._error_counts[error_type] += 1
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get metrics summary.
|
||||
|
||||
Returns:
|
||||
Dictionary with metrics summary
|
||||
"""
|
||||
total_requests = sum(self._request_counts.values())
|
||||
total_missing = sum(len(keys) for keys in self._missing_translations.values())
|
||||
|
||||
# Calculate timing statistics
|
||||
timing_stats = self._calculate_timing_stats()
|
||||
|
||||
# Calculate uptime
|
||||
uptime_seconds = (datetime.now() - self._start_time).total_seconds()
|
||||
|
||||
return {
|
||||
"uptime_seconds": round(uptime_seconds, 2),
|
||||
"total_requests": total_requests,
|
||||
"requests_per_locale": dict(self._request_counts),
|
||||
"total_missing_translations": total_missing,
|
||||
"missing_by_locale": {
|
||||
locale: len(keys)
|
||||
for locale, keys in self._missing_translations.items()
|
||||
},
|
||||
"timing": timing_stats,
|
||||
"locale_usage": dict(self._locale_usage),
|
||||
"namespace_usage": dict(self._namespace_usage),
|
||||
"error_counts": dict(self._error_counts)
|
||||
}
|
||||
|
||||
def _calculate_timing_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate timing statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with timing statistics
|
||||
"""
|
||||
if not self._timing_data:
|
||||
return {
|
||||
"count": 0,
|
||||
"avg_ms": 0,
|
||||
"min_ms": 0,
|
||||
"max_ms": 0,
|
||||
"p50_ms": 0,
|
||||
"p95_ms": 0,
|
||||
"p99_ms": 0
|
||||
}
|
||||
|
||||
durations = [d["duration_ms"] for d in self._timing_data]
|
||||
durations.sort()
|
||||
|
||||
count = len(durations)
|
||||
avg = sum(durations) / count
|
||||
|
||||
# Calculate percentiles
|
||||
p50_idx = int(count * 0.50)
|
||||
p95_idx = int(count * 0.95)
|
||||
p99_idx = int(count * 0.99)
|
||||
|
||||
return {
|
||||
"count": count,
|
||||
"avg_ms": round(avg, 3),
|
||||
"min_ms": round(durations[0], 3),
|
||||
"max_ms": round(durations[-1], 3),
|
||||
"p50_ms": round(durations[p50_idx], 3),
|
||||
"p95_ms": round(durations[p95_idx], 3),
|
||||
"p99_ms": round(durations[p99_idx], 3)
|
||||
}
|
||||
|
||||
def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]:
|
||||
"""
|
||||
Get missing translations.
|
||||
|
||||
Args:
|
||||
locale: Specific locale (optional, returns all if None)
|
||||
|
||||
Returns:
|
||||
Dictionary of missing translations by locale
|
||||
"""
|
||||
if locale:
|
||||
return {locale: list(self._missing_translations.get(locale, set()))}
|
||||
|
||||
return {
|
||||
locale: list(keys)
|
||||
for locale, keys in self._missing_translations.items()
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset all metrics."""
|
||||
self._request_counts.clear()
|
||||
self._missing_translations.clear()
|
||||
self._timing_data.clear()
|
||||
self._locale_usage.clear()
|
||||
self._namespace_usage.clear()
|
||||
self._error_counts.clear()
|
||||
self._start_time = datetime.now()
|
||||
logger.info("Metrics reset")
|
||||
|
||||
def export_prometheus(self) -> str:
|
||||
"""
|
||||
Export metrics in Prometheus format.
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics string
|
||||
"""
|
||||
lines = []
|
||||
|
||||
# Translation requests counter
|
||||
lines.append("# HELP i18n_translation_requests_total Total number of translation requests")
|
||||
lines.append("# TYPE i18n_translation_requests_total counter")
|
||||
for locale, count in self._request_counts.items():
|
||||
lines.append(f'i18n_translation_requests_total{{locale="{locale}"}} {count}')
|
||||
|
||||
# Missing translations counter
|
||||
lines.append("# HELP i18n_missing_translations_total Total number of missing translations")
|
||||
lines.append("# TYPE i18n_missing_translations_total counter")
|
||||
for locale, keys in self._missing_translations.items():
|
||||
lines.append(f'i18n_missing_translations_total{{locale="{locale}"}} {len(keys)}')
|
||||
|
||||
# Timing metrics
|
||||
timing_stats = self._calculate_timing_stats()
|
||||
lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds")
|
||||
lines.append("# TYPE i18n_translation_duration_ms summary")
|
||||
lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}')
|
||||
lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}')
|
||||
lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}')
|
||||
lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}')
|
||||
lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}')
|
||||
|
||||
# Error counter
|
||||
lines.append("# HELP i18n_errors_total Total number of i18n errors")
|
||||
lines.append("# TYPE i18n_errors_total counter")
|
||||
for error_type, count in self._error_counts.items():
|
||||
lines.append(f'i18n_errors_total{{type="{error_type}"}} {count}')
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Global metrics instance
|
||||
_metrics: Optional[TranslationMetrics] = None
|
||||
|
||||
|
||||
def get_metrics() -> TranslationMetrics:
|
||||
"""
|
||||
Get the global metrics instance.
|
||||
|
||||
Returns:
|
||||
TranslationMetrics singleton
|
||||
"""
|
||||
global _metrics
|
||||
if _metrics is None:
|
||||
_metrics = TranslationMetrics()
|
||||
return _metrics
|
||||
|
||||
|
||||
def monitor_performance(operation: str = "translate"):
|
||||
"""
|
||||
Decorator to monitor translation operation performance.
|
||||
|
||||
Args:
|
||||
operation: Operation name for metrics
|
||||
|
||||
Returns:
|
||||
Decorated function
|
||||
|
||||
Example:
|
||||
@monitor_performance("translate")
|
||||
def translate(key: str, locale: str) -> str:
|
||||
...
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Record timing
|
||||
duration_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
# Try to extract locale from args/kwargs
|
||||
locale = kwargs.get("locale", "unknown")
|
||||
if not locale and len(args) > 1:
|
||||
locale = args[1] if isinstance(args[1], str) else "unknown"
|
||||
|
||||
metrics = get_metrics()
|
||||
metrics.record_timing(duration_ms, locale, operation)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Record error
|
||||
metrics = get_metrics()
|
||||
metrics.record_error(type(e).__name__)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def track_missing_translation(key: str, locale: str):
|
||||
"""
|
||||
Track a missing translation.
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Locale code
|
||||
"""
|
||||
metrics = get_metrics()
|
||||
metrics.record_missing(key, locale)
|
||||
|
||||
|
||||
def track_translation_request(locale: str, namespace: str = None):
|
||||
"""
|
||||
Track a translation request.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace (optional)
|
||||
"""
|
||||
metrics = get_metrics()
|
||||
metrics.record_request(locale, namespace)
|
||||
202
api/app/i18n/middleware.py
Normal file
202
api/app/i18n/middleware.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Language detection middleware for i18n system.
|
||||
|
||||
This middleware determines the language to use for each request based on:
|
||||
1. Query parameter (?lang=en)
|
||||
2. Accept-Language HTTP header
|
||||
3. User language preference (from database)
|
||||
4. Tenant default language
|
||||
5. System default language
|
||||
|
||||
The detected language is injected into request.state.language and
|
||||
added to the response Content-Language header.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LanguageMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Language detection middleware.
|
||||
|
||||
Determines the language for each request based on multiple sources
|
||||
with a clear priority order, validates the language is supported,
|
||||
and injects it into the request context.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Process the request and determine the language.
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
call_next: The next middleware/handler in the chain
|
||||
|
||||
Returns:
|
||||
Response with Content-Language header added
|
||||
"""
|
||||
# Determine the language for this request
|
||||
language = await self._determine_language(request)
|
||||
|
||||
# Validate language is supported
|
||||
from app.core.config import settings
|
||||
if language not in settings.I18N_SUPPORTED_LANGUAGES:
|
||||
logger.warning(
|
||||
f"Unsupported language '{language}' requested, "
|
||||
f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}"
|
||||
)
|
||||
language = settings.I18N_DEFAULT_LANGUAGE
|
||||
|
||||
# Inject language into request state
|
||||
request.state.language = language
|
||||
|
||||
# Also set in context variable for exception handling
|
||||
from app.i18n.exceptions import set_current_locale
|
||||
set_current_locale(language)
|
||||
|
||||
logger.debug(f"Request language set to: {language}")
|
||||
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add Content-Language header to response
|
||||
response.headers["Content-Language"] = language
|
||||
|
||||
return response
|
||||
|
||||
async def _determine_language(self, request: Request) -> str:
|
||||
"""
|
||||
Determine the language to use based on priority order.
|
||||
|
||||
Priority:
|
||||
1. Query parameter (?lang=en)
|
||||
2. Accept-Language HTTP header
|
||||
3. User language preference (from database)
|
||||
4. Tenant default language
|
||||
5. System default language
|
||||
|
||||
Args:
|
||||
request: The incoming request
|
||||
|
||||
Returns:
|
||||
Language code (e.g., "zh", "en")
|
||||
"""
|
||||
from app.core.config import settings
|
||||
|
||||
# 1. Check query parameter (?lang=en)
|
||||
if "lang" in request.query_params:
|
||||
lang = request.query_params["lang"].strip().lower()
|
||||
if lang:
|
||||
logger.debug(f"Language from query parameter: {lang}")
|
||||
return lang
|
||||
|
||||
# 2. Check Accept-Language HTTP header
|
||||
if "Accept-Language" in request.headers:
|
||||
lang = self._parse_accept_language(
|
||||
request.headers["Accept-Language"]
|
||||
)
|
||||
if lang:
|
||||
logger.debug(f"Language from Accept-Language header: {lang}")
|
||||
return lang
|
||||
|
||||
# 3. Check user language preference (requires authentication)
|
||||
# Note: This assumes user is already loaded into request.state by auth middleware
|
||||
if hasattr(request.state, "user") and request.state.user:
|
||||
user = request.state.user
|
||||
if hasattr(user, "preferred_language") and user.preferred_language:
|
||||
logger.debug(
|
||||
f"Language from user preference: {user.preferred_language}"
|
||||
)
|
||||
return user.preferred_language
|
||||
|
||||
# 4. Check tenant default language
|
||||
# Note: This assumes tenant is already loaded into request.state
|
||||
if hasattr(request.state, "tenant") and request.state.tenant:
|
||||
tenant = request.state.tenant
|
||||
if hasattr(tenant, "default_language") and tenant.default_language:
|
||||
logger.debug(
|
||||
f"Language from tenant default: {tenant.default_language}"
|
||||
)
|
||||
return tenant.default_language
|
||||
|
||||
# 5. Fall back to system default language
|
||||
logger.debug(
|
||||
f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}"
|
||||
)
|
||||
return settings.I18N_DEFAULT_LANGUAGE
|
||||
|
||||
def _parse_accept_language(self, header: str) -> Optional[str]:
|
||||
"""
|
||||
Parse the Accept-Language HTTP header.
|
||||
|
||||
The Accept-Language header format:
|
||||
Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7
|
||||
|
||||
This method:
|
||||
1. Parses all language codes and their quality values
|
||||
2. Extracts the base language code (zh-CN -> zh)
|
||||
3. Sorts by quality value (higher first)
|
||||
4. Returns the first supported language
|
||||
|
||||
Args:
|
||||
header: Accept-Language header value
|
||||
|
||||
Returns:
|
||||
Language code if found and supported, None otherwise
|
||||
|
||||
Examples:
|
||||
_parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8")
|
||||
# => "zh" (if zh is supported)
|
||||
|
||||
_parse_accept_language("en-US,en;q=0.9")
|
||||
# => "en" (if en is supported)
|
||||
"""
|
||||
from app.core.config import settings
|
||||
|
||||
if not header:
|
||||
return None
|
||||
|
||||
# Parse language preferences with quality values
|
||||
languages = []
|
||||
|
||||
for item in header.split(","):
|
||||
item = item.strip()
|
||||
if not item:
|
||||
continue
|
||||
|
||||
# Split language code and quality value
|
||||
parts = item.split(";")
|
||||
lang_code = parts[0].strip()
|
||||
|
||||
# Extract base language code (zh-CN -> zh, en-US -> en)
|
||||
base_lang = lang_code.split("-")[0].lower()
|
||||
|
||||
# Extract quality value (default: 1.0)
|
||||
quality = 1.0
|
||||
if len(parts) > 1:
|
||||
# Look for q=0.9 pattern
|
||||
q_match = re.search(r"q=([\d.]+)", parts[1])
|
||||
if q_match:
|
||||
try:
|
||||
quality = float(q_match.group(1))
|
||||
except ValueError:
|
||||
quality = 1.0
|
||||
|
||||
languages.append((base_lang, quality))
|
||||
|
||||
# Sort by quality value (descending)
|
||||
languages.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# Return the first supported language
|
||||
for lang_code, _ in languages:
|
||||
if lang_code in settings.I18N_SUPPORTED_LANGUAGES:
|
||||
return lang_code
|
||||
|
||||
return None
|
||||
221
api/app/i18n/serializers.py
Normal file
221
api/app/i18n/serializers.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
国际化响应序列化器
|
||||
|
||||
提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class I18nResponseMixin:
|
||||
"""国际化响应混入类
|
||||
|
||||
为响应数据添加国际化字段,特别是为枚举值添加 _display 后缀的翻译字段。
|
||||
|
||||
使用方法:
|
||||
1. 继承此类
|
||||
2. 实现 _get_enum_fields() 方法定义需要翻译的枚举字段
|
||||
3. 调用 serialize_with_i18n() 方法序列化数据
|
||||
|
||||
示例:
|
||||
class WorkspaceSerializer(I18nResponseMixin):
|
||||
def _get_enum_fields(self) -> Dict[str, str]:
|
||||
return {
|
||||
"role": "workspace_role",
|
||||
"status": "workspace_status"
|
||||
}
|
||||
|
||||
def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict:
|
||||
data = {
|
||||
"id": str(workspace.id),
|
||||
"name": workspace.name,
|
||||
"role": workspace.role,
|
||||
"status": workspace.status
|
||||
}
|
||||
return self.serialize_with_i18n(data, locale)
|
||||
"""
|
||||
|
||||
def serialize_with_i18n(
|
||||
self,
|
||||
data: Any,
|
||||
locale: str = "zh"
|
||||
) -> Union[Dict, List[Dict], Any]:
|
||||
"""序列化数据并添加国际化字段
|
||||
|
||||
Args:
|
||||
data: 要序列化的数据(字典、列表或 Pydantic 模型)
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的数据,包含国际化字段
|
||||
"""
|
||||
# 如果是 Pydantic 模型,转换为字典
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
# 处理不同类型的数据
|
||||
if isinstance(data, dict):
|
||||
return self._serialize_dict(data, locale)
|
||||
elif isinstance(data, list):
|
||||
return [self._serialize_dict(item, locale) if isinstance(item, dict) else item for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
def _serialize_dict(self, data: Dict, locale: str) -> Dict:
|
||||
"""序列化字典并添加 _display 字段
|
||||
|
||||
Args:
|
||||
data: 字典数据
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
添加了 _display 字段的字典
|
||||
"""
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
result = data.copy()
|
||||
|
||||
# 获取需要翻译的枚举字段
|
||||
enum_fields = self._get_enum_fields()
|
||||
|
||||
# 为每个枚举字段添加 _display 字段
|
||||
for field, enum_type in enum_fields.items():
|
||||
if field in result and result[field] is not None:
|
||||
value = result[field]
|
||||
# 翻译枚举值
|
||||
display_value = translation_service.translate_enum(
|
||||
enum_type=enum_type,
|
||||
value=str(value),
|
||||
locale=locale
|
||||
)
|
||||
# 添加 _display 字段
|
||||
result[f"{field}_display"] = display_value
|
||||
|
||||
return result
|
||||
|
||||
def _get_enum_fields(self) -> Dict[str, str]:
|
||||
"""获取需要翻译的枚举字段
|
||||
|
||||
子类必须实现此方法,返回字段名到枚举类型的映射。
|
||||
|
||||
Returns:
|
||||
字段名到枚举类型的映射
|
||||
例如: {"role": "workspace_role", "status": "workspace_status"}
|
||||
"""
|
||||
return {}
|
||||
|
||||
|
||||
class WorkspaceSerializer(I18nResponseMixin):
|
||||
"""工作空间序列化器
|
||||
|
||||
为工作空间响应添加国际化字段。
|
||||
"""
|
||||
|
||||
def _get_enum_fields(self) -> Dict[str, str]:
|
||||
"""定义工作空间的枚举字段"""
|
||||
return {
|
||||
"role": "workspace_role",
|
||||
"status": "workspace_status"
|
||||
}
|
||||
|
||||
def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
|
||||
"""序列化工作空间数据
|
||||
|
||||
Args:
|
||||
workspace_data: 工作空间数据(字典或 Pydantic 模型)
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的工作空间数据,包含国际化字段
|
||||
"""
|
||||
return self.serialize_with_i18n(workspace_data, locale)
|
||||
|
||||
def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
|
||||
"""序列化工作空间列表
|
||||
|
||||
Args:
|
||||
workspaces: 工作空间列表
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的工作空间列表
|
||||
"""
|
||||
return [self.serialize(ws, locale) for ws in workspaces]
|
||||
|
||||
|
||||
class WorkspaceMemberSerializer(I18nResponseMixin):
|
||||
"""工作空间成员序列化器
|
||||
|
||||
为工作空间成员响应添加国际化字段。
|
||||
"""
|
||||
|
||||
def _get_enum_fields(self) -> Dict[str, str]:
|
||||
"""定义工作空间成员的枚举字段"""
|
||||
return {
|
||||
"role": "workspace_role"
|
||||
}
|
||||
|
||||
def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
|
||||
"""序列化工作空间成员数据
|
||||
|
||||
Args:
|
||||
member_data: 成员数据(字典或 Pydantic 模型)
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的成员数据,包含国际化字段
|
||||
"""
|
||||
return self.serialize_with_i18n(member_data, locale)
|
||||
|
||||
def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
|
||||
"""序列化工作空间成员列表
|
||||
|
||||
Args:
|
||||
members: 成员列表
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的成员列表
|
||||
"""
|
||||
return [self.serialize(member, locale) for member in members]
|
||||
|
||||
|
||||
class WorkspaceInviteSerializer(I18nResponseMixin):
|
||||
"""工作空间邀请序列化器
|
||||
|
||||
为工作空间邀请响应添加国际化字段。
|
||||
"""
|
||||
|
||||
def _get_enum_fields(self) -> Dict[str, str]:
|
||||
"""定义工作空间邀请的枚举字段"""
|
||||
return {
|
||||
"status": "invite_status",
|
||||
"role": "workspace_role"
|
||||
}
|
||||
|
||||
def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
|
||||
"""序列化工作空间邀请数据
|
||||
|
||||
Args:
|
||||
invite_data: 邀请数据(字典或 Pydantic 模型)
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的邀请数据,包含国际化字段
|
||||
"""
|
||||
return self.serialize_with_i18n(invite_data, locale)
|
||||
|
||||
def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
|
||||
"""序列化工作空间邀请列表
|
||||
|
||||
Args:
|
||||
invites: 邀请列表
|
||||
locale: 语言代码
|
||||
|
||||
Returns:
|
||||
序列化后的邀请列表
|
||||
"""
|
||||
return [self.serialize(invite, locale) for invite in invites]
|
||||
370
api/app/i18n/service.py
Normal file
370
api/app/i18n/service.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Translation service for i18n system.
|
||||
|
||||
This module provides the core translation functionality including:
|
||||
- Translation lookup with fallback mechanism
|
||||
- Parameterized message support
|
||||
- Enum value translation
|
||||
- Memory caching for performance
|
||||
- Performance monitoring and metrics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.i18n.loader import TranslationLoader
|
||||
from app.i18n.cache import TranslationCache
|
||||
from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request
|
||||
from app.i18n.logger import get_translation_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationService:
|
||||
"""
|
||||
Translation service that provides:
|
||||
- Fast translation lookup with memory cache
|
||||
- Parameterized message support ({param} syntax)
|
||||
- Fallback mechanism (current locale → default locale → key)
|
||||
- Enum value translation
|
||||
- Deep merge of multi-directory translations
|
||||
"""
|
||||
|
||||
def __init__(self, locales_dirs: Optional[list] = None):
|
||||
"""
|
||||
Initialize the translation service.
|
||||
|
||||
Args:
|
||||
locales_dirs: List of directories containing translation files.
|
||||
If None, will auto-detect from settings.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
|
||||
self.loader = TranslationLoader(locales_dirs)
|
||||
self.default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE
|
||||
self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS
|
||||
self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE
|
||||
|
||||
# Initialize advanced cache with LRU
|
||||
lru_cache_size = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000)
|
||||
self.cache = TranslationCache(
|
||||
max_lru_size=lru_cache_size,
|
||||
enable_lazy_load=False # Load all at startup for now
|
||||
)
|
||||
|
||||
# Load all translations into cache
|
||||
self._load_all_locales()
|
||||
|
||||
# Initialize metrics
|
||||
self.metrics = get_metrics()
|
||||
|
||||
# Initialize translation logger
|
||||
self.translation_logger = get_translation_logger()
|
||||
|
||||
logger.info(
|
||||
f"TranslationService initialized with default locale: {self.default_locale}, "
|
||||
f"LRU cache size: {lru_cache_size}"
|
||||
)
|
||||
|
||||
def _load_all_locales(self):
|
||||
"""Load all available locales into memory cache."""
|
||||
available_locales = self.loader.get_available_locales()
|
||||
logger.info(f"Loading translations for locales: {available_locales}")
|
||||
|
||||
for locale in available_locales:
|
||||
locale_data = self.loader.load_locale(locale)
|
||||
self.cache.set_locale_data(locale, locale_data)
|
||||
|
||||
logger.info(f"Loaded {len(available_locales)} locales into cache")
|
||||
|
||||
@monitor_performance("translate")
|
||||
def translate(
|
||||
self,
|
||||
key: str,
|
||||
locale: Optional[str] = None,
|
||||
**params
|
||||
) -> str:
|
||||
"""
|
||||
Translate a key to the target locale.
|
||||
|
||||
Supports:
|
||||
- Dot-separated keys (e.g., "common.success.created")
|
||||
- Parameterized messages (e.g., "Hello {name}")
|
||||
- Fallback mechanism
|
||||
|
||||
Args:
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
locale: Target locale (defaults to default locale)
|
||||
**params: Parameters for parameterized messages
|
||||
|
||||
Returns:
|
||||
Translated string, or the key itself if translation not found
|
||||
|
||||
Examples:
|
||||
translate("common.success.created", "zh")
|
||||
# => "创建成功"
|
||||
|
||||
translate("common.validation.required", "zh", field="名称")
|
||||
# => "名称不能为空"
|
||||
"""
|
||||
if locale is None:
|
||||
locale = self.default_locale
|
||||
|
||||
# Parse key (namespace.key.subkey)
|
||||
parts = key.split(".", 1)
|
||||
if len(parts) < 2:
|
||||
if self.log_missing:
|
||||
logger.warning(f"Invalid translation key format: {key}")
|
||||
return key
|
||||
|
||||
namespace = parts[0]
|
||||
key_path = parts[1].split(".")
|
||||
|
||||
# Track request
|
||||
track_translation_request(locale, namespace)
|
||||
|
||||
# Get translation from cache
|
||||
translation = self.cache.get_translation(locale, namespace, key_path)
|
||||
|
||||
# Fallback to default locale if not found
|
||||
if translation is None and locale != self.fallback_locale:
|
||||
translation = self.cache.get_translation(
|
||||
self.fallback_locale, namespace, key_path
|
||||
)
|
||||
|
||||
# If still not found, return the key itself
|
||||
if translation is None:
|
||||
if self.log_missing:
|
||||
logger.warning(
|
||||
f"Missing translation: {key} (locale: {locale})"
|
||||
)
|
||||
track_missing_translation(key, locale)
|
||||
|
||||
# Log to translation logger with context
|
||||
self.translation_logger.log_missing_translation(
|
||||
key=key,
|
||||
locale=locale,
|
||||
context={"namespace": namespace}
|
||||
)
|
||||
return key
|
||||
|
||||
# Apply parameters if provided
|
||||
if params:
|
||||
try:
|
||||
translation = translation.format(**params)
|
||||
except KeyError as e:
|
||||
error_msg = f"Missing parameter in translation '{key}': {e}"
|
||||
logger.error(error_msg)
|
||||
self.translation_logger.log_translation_error(
|
||||
error_type="parameter_missing",
|
||||
message=error_msg,
|
||||
key=key,
|
||||
locale=locale,
|
||||
context={"params": list(params.keys())}
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error formatting translation '{key}': {e}"
|
||||
logger.error(error_msg)
|
||||
self.translation_logger.log_translation_error(
|
||||
error_type="format_error",
|
||||
message=error_msg,
|
||||
key=key,
|
||||
locale=locale
|
||||
)
|
||||
|
||||
return translation
|
||||
|
||||
def _get_translation(
|
||||
self,
|
||||
locale: str,
|
||||
namespace: str,
|
||||
key_path: list
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get translation from cache (deprecated, use cache.get_translation).
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key_path: List of nested keys
|
||||
|
||||
Returns:
|
||||
Translation string or None if not found
|
||||
"""
|
||||
return self.cache.get_translation(locale, namespace, key_path)
|
||||
|
||||
@monitor_performance("translate_enum")
|
||||
def translate_enum(
|
||||
self,
|
||||
enum_type: str,
|
||||
value: str,
|
||||
locale: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Translate an enum value.
|
||||
|
||||
Args:
|
||||
enum_type: Enum type name (e.g., "workspace_role")
|
||||
value: Enum value (e.g., "manager")
|
||||
locale: Target locale
|
||||
|
||||
Returns:
|
||||
Translated enum display name
|
||||
|
||||
Examples:
|
||||
translate_enum("workspace_role", "manager", "zh")
|
||||
# => "管理员"
|
||||
|
||||
translate_enum("invite_status", "pending", "en")
|
||||
# => "Pending"
|
||||
"""
|
||||
key = f"enums.{enum_type}.{value}"
|
||||
return self.translate(key, locale)
|
||||
|
||||
def has_translation(self, key: str, locale: str) -> bool:
|
||||
"""
|
||||
Check if a translation exists for the given key and locale.
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Locale code
|
||||
|
||||
Returns:
|
||||
True if translation exists, False otherwise
|
||||
"""
|
||||
parts = key.split(".", 1)
|
||||
if len(parts) < 2:
|
||||
return False
|
||||
|
||||
namespace = parts[0]
|
||||
key_path = parts[1].split(".")
|
||||
|
||||
translation = self.cache.get_translation(locale, namespace, key_path)
|
||||
return translation is not None
|
||||
|
||||
def reload(self, locale: Optional[str] = None):
|
||||
"""
|
||||
Reload translation files.
|
||||
|
||||
Args:
|
||||
locale: Specific locale to reload. If None, reloads all locales.
|
||||
"""
|
||||
logger.info(f"Reloading translations for locale: {locale or 'all'}")
|
||||
|
||||
if locale:
|
||||
locale_data = self.loader.load_locale(locale)
|
||||
self.cache.set_locale_data(locale, locale_data)
|
||||
# Clear LRU cache for this locale
|
||||
self.cache.clear_locale(locale)
|
||||
else:
|
||||
self._load_all_locales()
|
||||
# Clear all LRU cache
|
||||
self.cache.clear_lru()
|
||||
|
||||
logger.info("Translation reload completed")
|
||||
|
||||
def get_available_locales(self) -> list:
|
||||
"""
|
||||
Get list of all available locales.
|
||||
|
||||
Returns:
|
||||
List of locale codes
|
||||
"""
|
||||
return self.cache.get_loaded_locales()
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
return self.cache.get_stats()
|
||||
|
||||
def get_metrics_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get metrics summary.
|
||||
|
||||
Returns:
|
||||
Dictionary with metrics summary
|
||||
"""
|
||||
return self.metrics.get_summary()
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get memory usage information.
|
||||
|
||||
Returns:
|
||||
Dictionary with memory usage information
|
||||
"""
|
||||
return self.cache.get_memory_usage()
|
||||
|
||||
def get_loaded_dirs(self) -> list:
|
||||
"""
|
||||
Get list of loaded translation directories.
|
||||
|
||||
Returns:
|
||||
List of directory paths
|
||||
"""
|
||||
return self.loader.locales_dirs
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_translation_service: Optional[TranslationService] = None
|
||||
|
||||
|
||||
def get_translation_service() -> TranslationService:
|
||||
"""
|
||||
Get the global translation service instance.
|
||||
|
||||
Returns:
|
||||
TranslationService singleton
|
||||
"""
|
||||
global _translation_service
|
||||
if _translation_service is None:
|
||||
_translation_service = TranslationService()
|
||||
return _translation_service
|
||||
|
||||
|
||||
# Convenience functions for easy access
|
||||
def t(key: str, locale: Optional[str] = None, **params) -> str:
|
||||
"""
|
||||
Translate a key (convenience function).
|
||||
|
||||
Args:
|
||||
key: Translation key
|
||||
locale: Target locale (optional, uses default if not provided)
|
||||
**params: Parameters for parameterized messages
|
||||
|
||||
Returns:
|
||||
Translated string
|
||||
|
||||
Examples:
|
||||
t("common.success.created")
|
||||
t("common.validation.required", field="名称")
|
||||
t("workspace.member_count", count=5)
|
||||
"""
|
||||
service = get_translation_service()
|
||||
return service.translate(key, locale, **params)
|
||||
|
||||
|
||||
def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str:
|
||||
"""
|
||||
Translate an enum value (convenience function).
|
||||
|
||||
Args:
|
||||
enum_type: Enum type name
|
||||
value: Enum value
|
||||
locale: Target locale
|
||||
|
||||
Returns:
|
||||
Translated enum display name
|
||||
|
||||
Examples:
|
||||
t_enum("workspace_role", "manager")
|
||||
t_enum("invite_status", "pending", "en")
|
||||
"""
|
||||
service = get_translation_service()
|
||||
return service.translate_enum(enum_type, value, locale)
|
||||
26
api/app/locales/en/README.md
Normal file
26
api/app/locales/en/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# English Translation Files
|
||||
|
||||
This directory contains English translation files.
|
||||
|
||||
## File Structure
|
||||
|
||||
- `common.json` - Common translations (success messages, actions, validation)
|
||||
- `auth.json` - Authentication module translations
|
||||
- `workspace.json` - Workspace module translations
|
||||
- `tenant.json` - Tenant module translations
|
||||
- `errors.json` - Error message translations
|
||||
- `enums.json` - Enum value translations
|
||||
|
||||
## Translation File Format
|
||||
|
||||
All translation files use JSON format and support nested structures.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"success": {
|
||||
"created": "Created successfully",
|
||||
"updated": "Updated successfully"
|
||||
}
|
||||
}
|
||||
```
|
||||
55
api/app/locales/en/auth.json
Normal file
55
api/app/locales/en/auth.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"login": {
|
||||
"success": "Login successful",
|
||||
"failed": "Login failed",
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"account_locked": "Account has been locked",
|
||||
"account_disabled": "Account has been disabled"
|
||||
},
|
||||
"logout": {
|
||||
"success": "Logout successful",
|
||||
"failed": "Logout failed"
|
||||
},
|
||||
"token": {
|
||||
"refresh_success": "Token refreshed successfully",
|
||||
"invalid": "Invalid token",
|
||||
"expired": "Token has expired",
|
||||
"blacklisted": "Token has been invalidated",
|
||||
"invalid_refresh_token": "Invalid refresh token",
|
||||
"refresh_token_blacklisted": "Refresh token has been invalidated"
|
||||
},
|
||||
"registration": {
|
||||
"success": "Registration successful",
|
||||
"failed": "Registration failed",
|
||||
"email_exists": "Email already in use",
|
||||
"username_exists": "Username already taken"
|
||||
},
|
||||
"password": {
|
||||
"reset_success": "Password reset successful",
|
||||
"reset_failed": "Password reset failed",
|
||||
"change_success": "Password changed successfully",
|
||||
"change_failed": "Password change failed",
|
||||
"incorrect": "Incorrect password",
|
||||
"too_weak": "Password is too weak",
|
||||
"mismatch": "Passwords do not match"
|
||||
},
|
||||
"invite": {
|
||||
"invalid": "Invalid or expired invite code",
|
||||
"email_mismatch": "Invite email does not match login email",
|
||||
"accept_success": "Invite accepted successfully",
|
||||
"accept_failed": "Failed to accept invite",
|
||||
"password_verification_failed": "Failed to accept invite, password verification error",
|
||||
"bind_workspace_success": "Workspace bound successfully",
|
||||
"bind_workspace_failed": "Failed to bind workspace"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"created_with_invite": "User created successfully and joined workspace"
|
||||
},
|
||||
"session": {
|
||||
"expired": "Session expired, please login again",
|
||||
"invalid": "Invalid session",
|
||||
"single_session_enabled": "Single sign-on enabled, other device sessions will be logged out"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/en/common.json
Normal file
132
api/app/locales/en/common.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"success": {
|
||||
"created": "Created successfully",
|
||||
"updated": "Updated successfully",
|
||||
"deleted": "Deleted successfully",
|
||||
"retrieved": "Retrieved successfully",
|
||||
"saved": "Saved successfully",
|
||||
"uploaded": "Uploaded successfully",
|
||||
"downloaded": "Downloaded successfully",
|
||||
"sent": "Sent successfully",
|
||||
"completed": "Completed",
|
||||
"confirmed": "Confirmed",
|
||||
"cancelled": "Cancelled",
|
||||
"archived": "Archived",
|
||||
"restored": "Restored"
|
||||
},
|
||||
"actions": {
|
||||
"create": "Create",
|
||||
"update": "Update",
|
||||
"delete": "Delete",
|
||||
"view": "View",
|
||||
"edit": "Edit",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Confirm",
|
||||
"submit": "Submit",
|
||||
"upload": "Upload",
|
||||
"download": "Download",
|
||||
"send": "Send",
|
||||
"search": "Search",
|
||||
"filter": "Filter",
|
||||
"sort": "Sort",
|
||||
"export": "Export",
|
||||
"import": "Import",
|
||||
"refresh": "Refresh",
|
||||
"reset": "Reset",
|
||||
"back": "Back",
|
||||
"next": "Next",
|
||||
"previous": "Previous",
|
||||
"finish": "Finish",
|
||||
"close": "Close",
|
||||
"open": "Open",
|
||||
"archive": "Archive",
|
||||
"restore": "Restore",
|
||||
"duplicate": "Duplicate",
|
||||
"share": "Share",
|
||||
"invite": "Invite",
|
||||
"remove": "Remove",
|
||||
"add": "Add",
|
||||
"select": "Select",
|
||||
"clear": "Clear"
|
||||
},
|
||||
"validation": {
|
||||
"required": "{field} is required",
|
||||
"invalid_format": "{field} format is invalid",
|
||||
"too_long": "{field} cannot exceed {max} characters",
|
||||
"too_short": "{field} must be at least {min} characters",
|
||||
"invalid_email": "Invalid email format",
|
||||
"invalid_url": "Invalid URL format",
|
||||
"invalid_phone": "Invalid phone number format",
|
||||
"invalid_date": "Invalid date format",
|
||||
"invalid_number": "Must be a valid number",
|
||||
"out_of_range": "{field} must be between {min} and {max}",
|
||||
"already_exists": "{field} already exists",
|
||||
"not_found": "{field} not found",
|
||||
"invalid_value": "Invalid value for {field}",
|
||||
"password_mismatch": "Passwords do not match",
|
||||
"weak_password": "Password is too weak, please use a stronger password",
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Permission denied",
|
||||
"expired": "{field} has expired",
|
||||
"invalid_token": "Invalid token",
|
||||
"file_too_large": "File size cannot exceed {max}",
|
||||
"invalid_file_type": "Unsupported file type",
|
||||
"duplicate": "Duplicate {field}"
|
||||
},
|
||||
"status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"pending": "Pending",
|
||||
"processing": "Processing",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled",
|
||||
"archived": "Archived",
|
||||
"deleted": "Deleted",
|
||||
"draft": "Draft",
|
||||
"published": "Published",
|
||||
"suspended": "Suspended",
|
||||
"expired": "Expired"
|
||||
},
|
||||
"messages": {
|
||||
"loading": "Loading...",
|
||||
"saving": "Saving...",
|
||||
"processing": "Processing...",
|
||||
"uploading": "Uploading...",
|
||||
"downloading": "Downloading...",
|
||||
"no_data": "No data available",
|
||||
"no_results": "No results found",
|
||||
"confirm_delete": "Are you sure you want to delete? This action cannot be undone.",
|
||||
"confirm_action": "Are you sure you want to perform this action?",
|
||||
"operation_success": "Operation successful",
|
||||
"operation_failed": "Operation failed",
|
||||
"please_wait": "Please wait...",
|
||||
"try_again": "Please try again",
|
||||
"contact_support": "If the problem persists, please contact support"
|
||||
},
|
||||
"pagination": {
|
||||
"page": "Page {page}",
|
||||
"of": "of {total}",
|
||||
"items": "{total} items",
|
||||
"per_page": "{count} per page",
|
||||
"showing": "Showing {from} to {to} of {total}",
|
||||
"first": "First",
|
||||
"last": "Last",
|
||||
"next": "Next",
|
||||
"previous": "Previous"
|
||||
},
|
||||
"time": {
|
||||
"just_now": "Just now",
|
||||
"minutes_ago": "{count} minutes ago",
|
||||
"hours_ago": "{count} hours ago",
|
||||
"days_ago": "{count} days ago",
|
||||
"weeks_ago": "{count} weeks ago",
|
||||
"months_ago": "{count} months ago",
|
||||
"years_ago": "{count} years ago",
|
||||
"today": "Today",
|
||||
"yesterday": "Yesterday",
|
||||
"tomorrow": "Tomorrow"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/en/enums.json
Normal file
132
api/app/locales/en/enums.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"workspace_role": {
|
||||
"owner": "Owner",
|
||||
"manager": "Manager",
|
||||
"member": "Member",
|
||||
"guest": "Guest"
|
||||
},
|
||||
"workspace_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"archived": "Archived",
|
||||
"suspended": "Suspended",
|
||||
"deleted": "Deleted"
|
||||
},
|
||||
"invite_status": {
|
||||
"pending": "Pending",
|
||||
"accepted": "Accepted",
|
||||
"rejected": "Rejected",
|
||||
"revoked": "Revoked",
|
||||
"expired": "Expired"
|
||||
},
|
||||
"user_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"suspended": "Suspended",
|
||||
"deleted": "Deleted",
|
||||
"pending": "Pending"
|
||||
},
|
||||
"tenant_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"suspended": "Suspended",
|
||||
"expired": "Expired",
|
||||
"trial": "Trial"
|
||||
},
|
||||
"file_status": {
|
||||
"uploading": "Uploading",
|
||||
"processing": "Processing",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"deleted": "Deleted"
|
||||
},
|
||||
"task_status": {
|
||||
"pending": "Pending",
|
||||
"running": "Running",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled",
|
||||
"paused": "Paused"
|
||||
},
|
||||
"priority": {
|
||||
"low": "Low",
|
||||
"medium": "Medium",
|
||||
"high": "High",
|
||||
"urgent": "Urgent"
|
||||
},
|
||||
"visibility": {
|
||||
"public": "Public",
|
||||
"private": "Private",
|
||||
"internal": "Internal",
|
||||
"shared": "Shared"
|
||||
},
|
||||
"permission": {
|
||||
"read": "Read",
|
||||
"write": "Write",
|
||||
"delete": "Delete",
|
||||
"admin": "Admin",
|
||||
"owner": "Owner"
|
||||
},
|
||||
"notification_type": {
|
||||
"info": "Info",
|
||||
"warning": "Warning",
|
||||
"error": "Error",
|
||||
"success": "Success"
|
||||
},
|
||||
"language": {
|
||||
"zh": "Chinese (Simplified)",
|
||||
"en": "English",
|
||||
"ja": "Japanese",
|
||||
"ko": "Korean",
|
||||
"fr": "French",
|
||||
"de": "German",
|
||||
"es": "Spanish"
|
||||
},
|
||||
"timezone": {
|
||||
"utc": "UTC",
|
||||
"asia_shanghai": "Asia/Shanghai",
|
||||
"asia_tokyo": "Asia/Tokyo",
|
||||
"america_new_york": "America/New_York",
|
||||
"europe_london": "Europe/London"
|
||||
},
|
||||
"date_format": {
|
||||
"short": "Short",
|
||||
"medium": "Medium",
|
||||
"long": "Long",
|
||||
"full": "Full"
|
||||
},
|
||||
"sort_order": {
|
||||
"asc": "Ascending",
|
||||
"desc": "Descending"
|
||||
},
|
||||
"filter_operator": {
|
||||
"equals": "Equals",
|
||||
"not_equals": "Not Equals",
|
||||
"contains": "Contains",
|
||||
"not_contains": "Not Contains",
|
||||
"starts_with": "Starts With",
|
||||
"ends_with": "Ends With",
|
||||
"greater_than": "Greater Than",
|
||||
"less_than": "Less Than",
|
||||
"greater_or_equal": "Greater or Equal",
|
||||
"less_or_equal": "Less or Equal",
|
||||
"in": "In",
|
||||
"not_in": "Not In",
|
||||
"is_null": "Is Null",
|
||||
"is_not_null": "Is Not Null"
|
||||
},
|
||||
"log_level": {
|
||||
"debug": "Debug",
|
||||
"info": "Info",
|
||||
"warning": "Warning",
|
||||
"error": "Error",
|
||||
"critical": "Critical"
|
||||
},
|
||||
"api_method": {
|
||||
"get": "GET",
|
||||
"post": "POST",
|
||||
"put": "PUT",
|
||||
"patch": "PATCH",
|
||||
"delete": "DELETE"
|
||||
}
|
||||
}
|
||||
138
api/app/locales/en/errors.json
Normal file
138
api/app/locales/en/errors.json
Normal file
@@ -0,0 +1,138 @@
|
||||
{
|
||||
"common": {
|
||||
"internal_error": "Internal server error",
|
||||
"network_error": "Network connection error",
|
||||
"timeout": "Request timeout",
|
||||
"service_unavailable": "Service temporarily unavailable",
|
||||
"bad_request": "Bad request parameters",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Access forbidden",
|
||||
"not_found": "Resource not found",
|
||||
"method_not_allowed": "Method not allowed",
|
||||
"conflict": "Resource conflict",
|
||||
"too_many_requests": "Too many requests, please try again later",
|
||||
"validation_failed": "Validation failed",
|
||||
"database_error": "Database operation failed",
|
||||
"file_operation_error": "File operation failed"
|
||||
},
|
||||
"auth": {
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"token_expired": "Session expired, please login again",
|
||||
"token_invalid": "Invalid authentication token",
|
||||
"token_missing": "Authentication token missing",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Permission denied",
|
||||
"account_locked": "Account has been locked",
|
||||
"account_disabled": "Account has been disabled",
|
||||
"account_not_verified": "Account not verified",
|
||||
"password_incorrect": "Incorrect password",
|
||||
"password_too_weak": "Password is too weak",
|
||||
"password_expired": "Password expired, please change it",
|
||||
"email_not_verified": "Email not verified",
|
||||
"phone_not_verified": "Phone number not verified",
|
||||
"verification_code_invalid": "Invalid verification code",
|
||||
"verification_code_expired": "Verification code expired",
|
||||
"login_failed": "Login failed",
|
||||
"logout_failed": "Logout failed",
|
||||
"session_expired": "Session expired",
|
||||
"already_logged_in": "Already logged in",
|
||||
"not_logged_in": "Not logged in"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"email_already_exists": "Email already in use",
|
||||
"phone_already_exists": "Phone number already in use",
|
||||
"username_already_exists": "Username already taken",
|
||||
"invalid_email": "Invalid email format",
|
||||
"invalid_phone": "Invalid phone number format",
|
||||
"invalid_username": "Invalid username format",
|
||||
"create_failed": "Failed to create user",
|
||||
"update_failed": "Failed to update user",
|
||||
"delete_failed": "Failed to delete user",
|
||||
"cannot_delete_self": "Cannot delete yourself",
|
||||
"cannot_update_self_role": "Cannot update your own role",
|
||||
"profile_update_failed": "Failed to update profile",
|
||||
"avatar_upload_failed": "Failed to upload avatar",
|
||||
"password_change_failed": "Failed to change password",
|
||||
"old_password_incorrect": "Old password is incorrect"
|
||||
},
|
||||
"workspace": {
|
||||
"not_found": "Workspace not found",
|
||||
"already_exists": "Workspace already exists",
|
||||
"name_required": "Workspace name is required",
|
||||
"name_too_long": "Workspace name is too long",
|
||||
"create_failed": "Failed to create workspace",
|
||||
"update_failed": "Failed to update workspace",
|
||||
"delete_failed": "Failed to delete workspace",
|
||||
"permission_denied": "Permission denied to access this workspace",
|
||||
"not_member": "Not a workspace member",
|
||||
"already_member": "Already a workspace member",
|
||||
"member_limit_reached": "Member limit reached",
|
||||
"cannot_leave_last_manager": "Cannot leave, you are the last manager",
|
||||
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||
"cannot_remove_self": "Cannot remove yourself",
|
||||
"invite_not_found": "Invite not found",
|
||||
"invite_expired": "Invite has expired",
|
||||
"invite_already_accepted": "Invite already accepted",
|
||||
"invite_already_revoked": "Invite already revoked",
|
||||
"invite_send_failed": "Failed to send invite",
|
||||
"archived": "Workspace is archived",
|
||||
"suspended": "Workspace is suspended"
|
||||
},
|
||||
"tenant": {
|
||||
"not_found": "Tenant not found",
|
||||
"already_exists": "Tenant already exists",
|
||||
"create_failed": "Failed to create tenant",
|
||||
"update_failed": "Failed to update tenant",
|
||||
"delete_failed": "Failed to delete tenant",
|
||||
"suspended": "Tenant is suspended",
|
||||
"expired": "Tenant has expired",
|
||||
"license_invalid": "Invalid license",
|
||||
"license_expired": "License has expired",
|
||||
"quota_exceeded": "Quota exceeded"
|
||||
},
|
||||
"file": {
|
||||
"not_found": "File not found",
|
||||
"upload_failed": "File upload failed",
|
||||
"download_failed": "File download failed",
|
||||
"delete_failed": "File deletion failed",
|
||||
"too_large": "File size exceeds limit",
|
||||
"invalid_type": "Unsupported file type",
|
||||
"invalid_format": "Invalid file format",
|
||||
"corrupted": "File is corrupted",
|
||||
"storage_full": "Storage is full",
|
||||
"access_denied": "Access denied to this file"
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API rate limit exceeded",
|
||||
"quota_exceeded": "API quota exceeded",
|
||||
"invalid_api_key": "Invalid API key",
|
||||
"api_key_expired": "API key has expired",
|
||||
"api_key_revoked": "API key has been revoked",
|
||||
"endpoint_not_found": "API endpoint not found",
|
||||
"method_not_allowed": "Method not allowed",
|
||||
"invalid_request": "Invalid request",
|
||||
"missing_parameter": "Missing required parameter: {param}",
|
||||
"invalid_parameter": "Invalid parameter: {param}"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "Database connection failed",
|
||||
"query_failed": "Database query failed",
|
||||
"transaction_failed": "Database transaction failed",
|
||||
"constraint_violation": "Data constraint violation",
|
||||
"duplicate_key": "Duplicate data",
|
||||
"foreign_key_violation": "Foreign key constraint violation",
|
||||
"deadlock": "Database deadlock"
|
||||
},
|
||||
"validation": {
|
||||
"invalid_input": "Invalid input data",
|
||||
"missing_field": "Missing required field: {field}",
|
||||
"invalid_field": "Invalid field: {field}",
|
||||
"field_too_long": "Field too long: {field}",
|
||||
"field_too_short": "Field too short: {field}",
|
||||
"invalid_format": "Invalid format: {field}",
|
||||
"invalid_value": "Invalid value: {field}",
|
||||
"out_of_range": "Value out of range: {field}"
|
||||
}
|
||||
}
|
||||
27
api/app/locales/en/i18n.json
Normal file
27
api/app/locales/en/i18n.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"language": {
|
||||
"not_found": "Language {locale} not found",
|
||||
"already_exists": "Language {locale} already exists",
|
||||
"add_instructions": "Language {locale} validated successfully. Please create translation files in {dir} directory to complete the addition.",
|
||||
"update_instructions": "Language {locale} update validated successfully. Please update I18N_SUPPORTED_LANGUAGES environment variable to apply configuration changes."
|
||||
},
|
||||
"namespace": {
|
||||
"not_found": "Namespace {namespace} not found in language {locale}"
|
||||
},
|
||||
"translation": {
|
||||
"invalid_key_format": "Invalid translation key format: {key}. Should use format: namespace.key.subkey",
|
||||
"update_instructions": "Translation {locale}/{key} update validated successfully. Please modify the corresponding JSON translation file to apply changes."
|
||||
},
|
||||
"reload": {
|
||||
"disabled": "Translation hot reload is disabled. Please enable I18N_ENABLE_HOT_RELOAD in configuration.",
|
||||
"success": "Translations reloaded successfully",
|
||||
"failed": "Translation reload failed: {error}"
|
||||
},
|
||||
"metrics": {
|
||||
"reset_success": "Performance metrics reset successfully"
|
||||
},
|
||||
"logs": {
|
||||
"export_success": "Missing translations exported to: {file}",
|
||||
"clear_success": "Missing translation logs cleared successfully"
|
||||
}
|
||||
}
|
||||
63
api/app/locales/en/tenant.json
Normal file
63
api/app/locales/en/tenant.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "Tenant information retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant information",
|
||||
"update_success": "Tenant information updated successfully",
|
||||
"update_failed": "Failed to update tenant information"
|
||||
},
|
||||
"create": {
|
||||
"success": "Tenant created successfully",
|
||||
"failed": "Failed to create tenant"
|
||||
},
|
||||
"delete": {
|
||||
"success": "Tenant deleted successfully",
|
||||
"failed": "Failed to delete tenant"
|
||||
},
|
||||
"status": {
|
||||
"activate_success": "Tenant activated successfully",
|
||||
"activate_failed": "Failed to activate tenant",
|
||||
"deactivate_success": "Tenant deactivated successfully",
|
||||
"deactivate_failed": "Failed to deactivate tenant"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "Tenant language configuration retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant language configuration",
|
||||
"update_success": "Tenant language configuration updated successfully",
|
||||
"update_failed": "Failed to update tenant language configuration",
|
||||
"invalid_language": "Unsupported language code",
|
||||
"default_not_in_supported": "Default language must be in the supported languages list"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "Tenant list retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant list"
|
||||
},
|
||||
"users": {
|
||||
"list_success": "Tenant user list retrieved successfully",
|
||||
"list_failed": "Failed to retrieve tenant user list",
|
||||
"assign_success": "User assigned to tenant successfully",
|
||||
"assign_failed": "Failed to assign user to tenant",
|
||||
"remove_success": "User removed from tenant successfully",
|
||||
"remove_failed": "Failed to remove user from tenant"
|
||||
},
|
||||
"statistics": {
|
||||
"get_success": "Tenant statistics retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant statistics"
|
||||
},
|
||||
"validation": {
|
||||
"name_required": "Tenant name is required",
|
||||
"name_invalid": "Invalid tenant name format",
|
||||
"name_too_long": "Tenant name cannot exceed {max} characters",
|
||||
"description_too_long": "Tenant description cannot exceed {max} characters",
|
||||
"language_code_invalid": "Invalid language code format",
|
||||
"supported_languages_empty": "Supported languages list cannot be empty"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "Tenant not found",
|
||||
"already_exists": "Tenant name already exists",
|
||||
"permission_denied": "Permission denied to access this tenant",
|
||||
"has_users": "Cannot delete tenant, associated users exist",
|
||||
"has_workspaces": "Cannot delete tenant, associated workspaces exist",
|
||||
"already_active": "Tenant is already active",
|
||||
"already_inactive": "Tenant is already inactive"
|
||||
}
|
||||
}
|
||||
72
api/app/locales/en/users.json
Normal file
72
api/app/locales/en/users.json
Normal file
@@ -0,0 +1,72 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "User information retrieved successfully",
|
||||
"get_failed": "Failed to retrieve user information",
|
||||
"update_success": "User information updated successfully",
|
||||
"update_failed": "Failed to update user information"
|
||||
},
|
||||
"create": {
|
||||
"success": "User created successfully",
|
||||
"failed": "Failed to create user",
|
||||
"superuser_success": "Superuser created successfully",
|
||||
"superuser_failed": "Failed to create superuser"
|
||||
},
|
||||
"delete": {
|
||||
"success": "User deleted successfully",
|
||||
"failed": "Failed to delete user",
|
||||
"deactivate_success": "User deactivated successfully",
|
||||
"deactivate_failed": "Failed to deactivate user"
|
||||
},
|
||||
"activate": {
|
||||
"success": "User activated successfully",
|
||||
"failed": "Failed to activate user"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "Language preference retrieved successfully",
|
||||
"get_failed": "Failed to retrieve language preference",
|
||||
"update_success": "Language preference updated successfully",
|
||||
"update_failed": "Failed to update language preference",
|
||||
"invalid_language": "Unsupported language code",
|
||||
"current": "Current language preference"
|
||||
},
|
||||
"email": {
|
||||
"change_success": "Email changed successfully",
|
||||
"change_failed": "Failed to change email",
|
||||
"code_sent": "Verification code has been sent to your email",
|
||||
"code_send_failed": "Failed to send verification code",
|
||||
"code_invalid": "Invalid or expired verification code",
|
||||
"already_exists": "Email already in use"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "User list retrieved successfully",
|
||||
"get_failed": "Failed to retrieve user list",
|
||||
"superusers_success": "Tenant superuser list retrieved successfully",
|
||||
"superusers_failed": "Failed to retrieve tenant superuser list"
|
||||
},
|
||||
"validation": {
|
||||
"username_required": "Username is required",
|
||||
"username_invalid": "Invalid username format",
|
||||
"username_too_long": "Username cannot exceed {max} characters",
|
||||
"email_required": "Email is required",
|
||||
"email_invalid": "Invalid email format",
|
||||
"password_required": "Password is required",
|
||||
"password_too_short": "Password must be at least {min} characters",
|
||||
"password_too_long": "Password cannot exceed {max} characters",
|
||||
"old_password_required": "Old password is required",
|
||||
"new_password_required": "New password is required",
|
||||
"verification_code_required": "Verification code is required",
|
||||
"verification_code_invalid": "Invalid verification code format"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"permission_denied": "Permission denied to access this user",
|
||||
"cannot_delete_self": "Cannot delete yourself",
|
||||
"cannot_deactivate_self": "Cannot deactivate yourself",
|
||||
"already_deactivated": "User is already deactivated",
|
||||
"already_activated": "User is already activated",
|
||||
"password_verification_failed": "Password verification failed",
|
||||
"old_password_incorrect": "Old password is incorrect",
|
||||
"same_as_old_password": "New password cannot be the same as old password"
|
||||
}
|
||||
}
|
||||
44
api/app/locales/en/workspace.json
Normal file
44
api/app/locales/en/workspace.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"list_retrieved": "Workspace list retrieved successfully",
|
||||
"created": "Workspace created successfully",
|
||||
"updated": "Workspace updated successfully",
|
||||
"deleted": "Workspace deleted successfully",
|
||||
"switched": "Workspace switched successfully",
|
||||
"not_found": "Workspace not found or access denied",
|
||||
"already_exists": "Workspace already exists",
|
||||
"permission_denied": "No permission to access this workspace",
|
||||
"name_required": "Workspace name is required",
|
||||
"invalid_name": "Invalid workspace name format",
|
||||
"members": {
|
||||
"list_retrieved": "Workspace members list retrieved successfully",
|
||||
"role_updated": "Member role updated successfully",
|
||||
"deleted": "Member deleted successfully",
|
||||
"not_found": "Member not found",
|
||||
"cannot_remove_self": "Cannot remove yourself",
|
||||
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||
"already_member": "User is already a workspace member"
|
||||
},
|
||||
"invites": {
|
||||
"created": "Invite created successfully",
|
||||
"list_retrieved": "Invite list retrieved successfully",
|
||||
"validated": "Invite validated successfully",
|
||||
"revoked": "Invite revoked successfully",
|
||||
"accepted": "Invite accepted",
|
||||
"not_found": "Invite not found",
|
||||
"expired": "Invite has expired",
|
||||
"already_used": "Invite has already been used",
|
||||
"invalid_token": "Invalid invite token",
|
||||
"email_required": "Email address is required",
|
||||
"invalid_email": "Invalid email address format"
|
||||
},
|
||||
"storage": {
|
||||
"type_retrieved": "Storage type retrieved successfully",
|
||||
"type_updated": "Storage type updated successfully",
|
||||
"invalid_type": "Invalid storage type"
|
||||
},
|
||||
"models": {
|
||||
"config_retrieved": "Model configuration retrieved successfully",
|
||||
"config_updated": "Model configuration updated successfully",
|
||||
"invalid_config": "Invalid model configuration"
|
||||
}
|
||||
}
|
||||
26
api/app/locales/zh/README.md
Normal file
26
api/app/locales/zh/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# 中文翻译文件
|
||||
|
||||
此目录包含中文(简体)的翻译文件。
|
||||
|
||||
## 文件结构
|
||||
|
||||
- `common.json` - 通用翻译(成功消息、操作、验证)
|
||||
- `auth.json` - 认证模块翻译
|
||||
- `workspace.json` - 工作空间模块翻译
|
||||
- `tenant.json` - 租户模块翻译
|
||||
- `errors.json` - 错误消息翻译
|
||||
- `enums.json` - 枚举值翻译
|
||||
|
||||
## 翻译文件格式
|
||||
|
||||
所有翻译文件使用 JSON 格式,支持嵌套结构。
|
||||
|
||||
示例:
|
||||
```json
|
||||
{
|
||||
"success": {
|
||||
"created": "创建成功",
|
||||
"updated": "更新成功"
|
||||
}
|
||||
}
|
||||
```
|
||||
55
api/app/locales/zh/auth.json
Normal file
55
api/app/locales/zh/auth.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"login": {
|
||||
"success": "登录成功",
|
||||
"failed": "登录失败",
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"account_locked": "账户已被锁定",
|
||||
"account_disabled": "账户已被禁用"
|
||||
},
|
||||
"logout": {
|
||||
"success": "登出成功",
|
||||
"failed": "登出失败"
|
||||
},
|
||||
"token": {
|
||||
"refresh_success": "token刷新成功",
|
||||
"invalid": "无效的token",
|
||||
"expired": "token已过期",
|
||||
"blacklisted": "token已失效",
|
||||
"invalid_refresh_token": "无效的refresh token",
|
||||
"refresh_token_blacklisted": "Refresh token已失效"
|
||||
},
|
||||
"registration": {
|
||||
"success": "注册成功",
|
||||
"failed": "注册失败",
|
||||
"email_exists": "邮箱已被使用",
|
||||
"username_exists": "用户名已被使用"
|
||||
},
|
||||
"password": {
|
||||
"reset_success": "密码重置成功",
|
||||
"reset_failed": "密码重置失败",
|
||||
"change_success": "密码修改成功",
|
||||
"change_failed": "密码修改失败",
|
||||
"incorrect": "密码错误",
|
||||
"too_weak": "密码强度不够",
|
||||
"mismatch": "两次输入的密码不一致"
|
||||
},
|
||||
"invite": {
|
||||
"invalid": "邀请码无效或已过期",
|
||||
"email_mismatch": "邀请邮箱与登录邮箱不匹配",
|
||||
"accept_success": "接受邀请成功",
|
||||
"accept_failed": "接受邀请失败",
|
||||
"password_verification_failed": "接受邀请失败,密码验证错误",
|
||||
"bind_workspace_success": "绑定工作空间成功",
|
||||
"bind_workspace_failed": "绑定工作空间失败"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"created_with_invite": "用户创建成功并已加入工作空间"
|
||||
},
|
||||
"session": {
|
||||
"expired": "会话已过期,请重新登录",
|
||||
"invalid": "无效的会话",
|
||||
"single_session_enabled": "单点登录已启用,其他设备的登录将被注销"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/zh/common.json
Normal file
132
api/app/locales/zh/common.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"success": {
|
||||
"created": "创建成功",
|
||||
"updated": "更新成功",
|
||||
"deleted": "删除成功",
|
||||
"retrieved": "获取成功",
|
||||
"saved": "保存成功",
|
||||
"uploaded": "上传成功",
|
||||
"downloaded": "下载成功",
|
||||
"sent": "发送成功",
|
||||
"completed": "完成",
|
||||
"confirmed": "已确认",
|
||||
"cancelled": "已取消",
|
||||
"archived": "已归档",
|
||||
"restored": "已恢复"
|
||||
},
|
||||
"actions": {
|
||||
"create": "创建",
|
||||
"update": "更新",
|
||||
"delete": "删除",
|
||||
"view": "查看",
|
||||
"edit": "编辑",
|
||||
"save": "保存",
|
||||
"cancel": "取消",
|
||||
"confirm": "确认",
|
||||
"submit": "提交",
|
||||
"upload": "上传",
|
||||
"download": "下载",
|
||||
"send": "发送",
|
||||
"search": "搜索",
|
||||
"filter": "筛选",
|
||||
"sort": "排序",
|
||||
"export": "导出",
|
||||
"import": "导入",
|
||||
"refresh": "刷新",
|
||||
"reset": "重置",
|
||||
"back": "返回",
|
||||
"next": "下一步",
|
||||
"previous": "上一步",
|
||||
"finish": "完成",
|
||||
"close": "关闭",
|
||||
"open": "打开",
|
||||
"archive": "归档",
|
||||
"restore": "恢复",
|
||||
"duplicate": "复制",
|
||||
"share": "分享",
|
||||
"invite": "邀请",
|
||||
"remove": "移除",
|
||||
"add": "添加",
|
||||
"select": "选择",
|
||||
"clear": "清除"
|
||||
},
|
||||
"validation": {
|
||||
"required": "{field}不能为空",
|
||||
"invalid_format": "{field}格式不正确",
|
||||
"too_long": "{field}长度不能超过{max}个字符",
|
||||
"too_short": "{field}长度不能少于{min}个字符",
|
||||
"invalid_email": "邮箱格式不正确",
|
||||
"invalid_url": "URL格式不正确",
|
||||
"invalid_phone": "手机号格式不正确",
|
||||
"invalid_date": "日期格式不正确",
|
||||
"invalid_number": "必须是有效的数字",
|
||||
"out_of_range": "{field}必须在{min}和{max}之间",
|
||||
"already_exists": "{field}已存在",
|
||||
"not_found": "{field}不存在",
|
||||
"invalid_value": "{field}的值无效",
|
||||
"password_mismatch": "两次输入的密码不一致",
|
||||
"weak_password": "密码强度不够,请使用更复杂的密码",
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限执行此操作",
|
||||
"expired": "{field}已过期",
|
||||
"invalid_token": "无效的令牌",
|
||||
"file_too_large": "文件大小不能超过{max}",
|
||||
"invalid_file_type": "不支持的文件类型",
|
||||
"duplicate": "重复的{field}"
|
||||
},
|
||||
"status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"pending": "待处理",
|
||||
"processing": "处理中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"cancelled": "已取消",
|
||||
"archived": "已归档",
|
||||
"deleted": "已删除",
|
||||
"draft": "草稿",
|
||||
"published": "已发布",
|
||||
"suspended": "已暂停",
|
||||
"expired": "已过期"
|
||||
},
|
||||
"messages": {
|
||||
"loading": "加载中...",
|
||||
"saving": "保存中...",
|
||||
"processing": "处理中...",
|
||||
"uploading": "上传中...",
|
||||
"downloading": "下载中...",
|
||||
"no_data": "暂无数据",
|
||||
"no_results": "没有找到结果",
|
||||
"confirm_delete": "确定要删除吗?此操作不可恢复。",
|
||||
"confirm_action": "确定要执行此操作吗?",
|
||||
"operation_success": "操作成功",
|
||||
"operation_failed": "操作失败",
|
||||
"please_wait": "请稍候...",
|
||||
"try_again": "请重试",
|
||||
"contact_support": "如果问题持续,请联系技术支持"
|
||||
},
|
||||
"pagination": {
|
||||
"page": "第{page}页",
|
||||
"of": "共{total}页",
|
||||
"items": "共{total}条",
|
||||
"per_page": "每页{count}条",
|
||||
"showing": "显示第{from}到第{to}条,共{total}条",
|
||||
"first": "首页",
|
||||
"last": "末页",
|
||||
"next": "下一页",
|
||||
"previous": "上一页"
|
||||
},
|
||||
"time": {
|
||||
"just_now": "刚刚",
|
||||
"minutes_ago": "{count}分钟前",
|
||||
"hours_ago": "{count}小时前",
|
||||
"days_ago": "{count}天前",
|
||||
"weeks_ago": "{count}周前",
|
||||
"months_ago": "{count}个月前",
|
||||
"years_ago": "{count}年前",
|
||||
"today": "今天",
|
||||
"yesterday": "昨天",
|
||||
"tomorrow": "明天"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/zh/enums.json
Normal file
132
api/app/locales/zh/enums.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"workspace_role": {
|
||||
"owner": "所有者",
|
||||
"manager": "管理员",
|
||||
"member": "成员",
|
||||
"guest": "访客"
|
||||
},
|
||||
"workspace_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"archived": "已归档",
|
||||
"suspended": "已暂停",
|
||||
"deleted": "已删除"
|
||||
},
|
||||
"invite_status": {
|
||||
"pending": "待处理",
|
||||
"accepted": "已接受",
|
||||
"rejected": "已拒绝",
|
||||
"revoked": "已撤销",
|
||||
"expired": "已过期"
|
||||
},
|
||||
"user_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"suspended": "已暂停",
|
||||
"deleted": "已删除",
|
||||
"pending": "待激活"
|
||||
},
|
||||
"tenant_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"suspended": "已暂停",
|
||||
"expired": "已过期",
|
||||
"trial": "试用中"
|
||||
},
|
||||
"file_status": {
|
||||
"uploading": "上传中",
|
||||
"processing": "处理中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"deleted": "已删除"
|
||||
},
|
||||
"task_status": {
|
||||
"pending": "待处理",
|
||||
"running": "运行中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"cancelled": "已取消",
|
||||
"paused": "已暂停"
|
||||
},
|
||||
"priority": {
|
||||
"low": "低",
|
||||
"medium": "中",
|
||||
"high": "高",
|
||||
"urgent": "紧急"
|
||||
},
|
||||
"visibility": {
|
||||
"public": "公开",
|
||||
"private": "私有",
|
||||
"internal": "内部",
|
||||
"shared": "共享"
|
||||
},
|
||||
"permission": {
|
||||
"read": "读取",
|
||||
"write": "写入",
|
||||
"delete": "删除",
|
||||
"admin": "管理",
|
||||
"owner": "所有者"
|
||||
},
|
||||
"notification_type": {
|
||||
"info": "信息",
|
||||
"warning": "警告",
|
||||
"error": "错误",
|
||||
"success": "成功"
|
||||
},
|
||||
"language": {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
},
|
||||
"timezone": {
|
||||
"utc": "UTC",
|
||||
"asia_shanghai": "亚洲/上海",
|
||||
"asia_tokyo": "亚洲/东京",
|
||||
"america_new_york": "美洲/纽约",
|
||||
"europe_london": "欧洲/伦敦"
|
||||
},
|
||||
"date_format": {
|
||||
"short": "短日期",
|
||||
"medium": "中等日期",
|
||||
"long": "长日期",
|
||||
"full": "完整日期"
|
||||
},
|
||||
"sort_order": {
|
||||
"asc": "升序",
|
||||
"desc": "降序"
|
||||
},
|
||||
"filter_operator": {
|
||||
"equals": "等于",
|
||||
"not_equals": "不等于",
|
||||
"contains": "包含",
|
||||
"not_contains": "不包含",
|
||||
"starts_with": "开始于",
|
||||
"ends_with": "结束于",
|
||||
"greater_than": "大于",
|
||||
"less_than": "小于",
|
||||
"greater_or_equal": "大于等于",
|
||||
"less_or_equal": "小于等于",
|
||||
"in": "在列表中",
|
||||
"not_in": "不在列表中",
|
||||
"is_null": "为空",
|
||||
"is_not_null": "不为空"
|
||||
},
|
||||
"log_level": {
|
||||
"debug": "调试",
|
||||
"info": "信息",
|
||||
"warning": "警告",
|
||||
"error": "错误",
|
||||
"critical": "严重"
|
||||
},
|
||||
"api_method": {
|
||||
"get": "GET",
|
||||
"post": "POST",
|
||||
"put": "PUT",
|
||||
"patch": "PATCH",
|
||||
"delete": "DELETE"
|
||||
}
|
||||
}
|
||||
138
api/app/locales/zh/errors.json
Normal file
138
api/app/locales/zh/errors.json
Normal file
@@ -0,0 +1,138 @@
|
||||
{
|
||||
"common": {
|
||||
"internal_error": "服务器内部错误",
|
||||
"network_error": "网络连接错误",
|
||||
"timeout": "请求超时",
|
||||
"service_unavailable": "服务暂时不可用",
|
||||
"bad_request": "请求参数错误",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限访问",
|
||||
"not_found": "请求的资源不存在",
|
||||
"method_not_allowed": "不支持的请求方法",
|
||||
"conflict": "资源冲突",
|
||||
"too_many_requests": "请求过于频繁,请稍后再试",
|
||||
"validation_failed": "数据验证失败",
|
||||
"database_error": "数据库操作失败",
|
||||
"file_operation_error": "文件操作失败"
|
||||
},
|
||||
"auth": {
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"token_expired": "登录已过期,请重新登录",
|
||||
"token_invalid": "无效的登录令牌",
|
||||
"token_missing": "缺少登录令牌",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限执行此操作",
|
||||
"account_locked": "账户已被锁定",
|
||||
"account_disabled": "账户已被禁用",
|
||||
"account_not_verified": "账户未验证",
|
||||
"password_incorrect": "密码错误",
|
||||
"password_too_weak": "密码强度不够",
|
||||
"password_expired": "密码已过期,请修改密码",
|
||||
"email_not_verified": "邮箱未验证",
|
||||
"phone_not_verified": "手机号未验证",
|
||||
"verification_code_invalid": "验证码无效",
|
||||
"verification_code_expired": "验证码已过期",
|
||||
"login_failed": "登录失败",
|
||||
"logout_failed": "登出失败",
|
||||
"session_expired": "会话已过期",
|
||||
"already_logged_in": "已经登录",
|
||||
"not_logged_in": "未登录"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"email_already_exists": "邮箱已被使用",
|
||||
"phone_already_exists": "手机号已被使用",
|
||||
"username_already_exists": "用户名已被使用",
|
||||
"invalid_email": "邮箱格式不正确",
|
||||
"invalid_phone": "手机号格式不正确",
|
||||
"invalid_username": "用户名格式不正确",
|
||||
"create_failed": "创建用户失败",
|
||||
"update_failed": "更新用户失败",
|
||||
"delete_failed": "删除用户失败",
|
||||
"cannot_delete_self": "不能删除自己",
|
||||
"cannot_update_self_role": "不能修改自己的角色",
|
||||
"profile_update_failed": "更新个人资料失败",
|
||||
"avatar_upload_failed": "上传头像失败",
|
||||
"password_change_failed": "修改密码失败",
|
||||
"old_password_incorrect": "原密码错误"
|
||||
},
|
||||
"workspace": {
|
||||
"not_found": "工作空间不存在",
|
||||
"already_exists": "工作空间已存在",
|
||||
"name_required": "工作空间名称不能为空",
|
||||
"name_too_long": "工作空间名称过长",
|
||||
"create_failed": "创建工作空间失败",
|
||||
"update_failed": "更新工作空间失败",
|
||||
"delete_failed": "删除工作空间失败",
|
||||
"permission_denied": "没有权限访问此工作空间",
|
||||
"not_member": "不是工作空间成员",
|
||||
"already_member": "已经是工作空间成员",
|
||||
"member_limit_reached": "成员数量已达上限",
|
||||
"cannot_leave_last_manager": "不能离开,您是最后一个管理员",
|
||||
"cannot_remove_last_manager": "不能移除最后一个管理员",
|
||||
"cannot_remove_self": "不能移除自己",
|
||||
"invite_not_found": "邀请不存在",
|
||||
"invite_expired": "邀请已过期",
|
||||
"invite_already_accepted": "邀请已被接受",
|
||||
"invite_already_revoked": "邀请已被撤销",
|
||||
"invite_send_failed": "发送邀请失败",
|
||||
"archived": "工作空间已归档",
|
||||
"suspended": "工作空间已暂停"
|
||||
},
|
||||
"tenant": {
|
||||
"not_found": "租户不存在",
|
||||
"already_exists": "租户已存在",
|
||||
"create_failed": "创建租户失败",
|
||||
"update_failed": "更新租户失败",
|
||||
"delete_failed": "删除租户失败",
|
||||
"suspended": "租户已暂停",
|
||||
"expired": "租户已过期",
|
||||
"license_invalid": "许可证无效",
|
||||
"license_expired": "许可证已过期",
|
||||
"quota_exceeded": "配额已超限"
|
||||
},
|
||||
"file": {
|
||||
"not_found": "文件不存在",
|
||||
"upload_failed": "文件上传失败",
|
||||
"download_failed": "文件下载失败",
|
||||
"delete_failed": "文件删除失败",
|
||||
"too_large": "文件大小超过限制",
|
||||
"invalid_type": "不支持的文件类型",
|
||||
"invalid_format": "文件格式不正确",
|
||||
"corrupted": "文件已损坏",
|
||||
"storage_full": "存储空间已满",
|
||||
"access_denied": "没有权限访问此文件"
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API调用频率超限",
|
||||
"quota_exceeded": "API调用配额已用完",
|
||||
"invalid_api_key": "无效的API密钥",
|
||||
"api_key_expired": "API密钥已过期",
|
||||
"api_key_revoked": "API密钥已被撤销",
|
||||
"endpoint_not_found": "API端点不存在",
|
||||
"method_not_allowed": "不支持的请求方法",
|
||||
"invalid_request": "无效的请求",
|
||||
"missing_parameter": "缺少必需参数:{param}",
|
||||
"invalid_parameter": "参数无效:{param}"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "数据库连接失败",
|
||||
"query_failed": "数据库查询失败",
|
||||
"transaction_failed": "数据库事务失败",
|
||||
"constraint_violation": "数据约束冲突",
|
||||
"duplicate_key": "数据重复",
|
||||
"foreign_key_violation": "外键约束冲突",
|
||||
"deadlock": "数据库死锁"
|
||||
},
|
||||
"validation": {
|
||||
"invalid_input": "输入数据无效",
|
||||
"missing_field": "缺少必需字段:{field}",
|
||||
"invalid_field": "字段无效:{field}",
|
||||
"field_too_long": "字段过长:{field}",
|
||||
"field_too_short": "字段过短:{field}",
|
||||
"invalid_format": "格式不正确:{field}",
|
||||
"invalid_value": "值无效:{field}",
|
||||
"out_of_range": "值超出范围:{field}"
|
||||
}
|
||||
}
|
||||
27
api/app/locales/zh/i18n.json
Normal file
27
api/app/locales/zh/i18n.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"language": {
|
||||
"not_found": "语言 {locale} 不存在",
|
||||
"already_exists": "语言 {locale} 已存在",
|
||||
"add_instructions": "语言 {locale} 验证成功。请在 {dir} 目录下创建翻译文件以完成添加。",
|
||||
"update_instructions": "语言 {locale} 更新验证成功。请更新环境变量 I18N_SUPPORTED_LANGUAGES 以应用配置更改。"
|
||||
},
|
||||
"namespace": {
|
||||
"not_found": "命名空间 {namespace} 在语言 {locale} 中不存在"
|
||||
},
|
||||
"translation": {
|
||||
"invalid_key_format": "翻译键格式无效: {key}。应使用格式: namespace.key.subkey",
|
||||
"update_instructions": "翻译 {locale}/{key} 更新验证成功。请修改对应的 JSON 翻译文件以应用更改。"
|
||||
},
|
||||
"reload": {
|
||||
"disabled": "翻译热重载功能已禁用。请在配置中启用 I18N_ENABLE_HOT_RELOAD。",
|
||||
"success": "翻译重载成功",
|
||||
"failed": "翻译重载失败: {error}"
|
||||
},
|
||||
"metrics": {
|
||||
"reset_success": "性能指标已重置"
|
||||
},
|
||||
"logs": {
|
||||
"export_success": "缺失翻译已导出到: {file}",
|
||||
"clear_success": "缺失翻译日志已清除"
|
||||
}
|
||||
}
|
||||
63
api/app/locales/zh/tenant.json
Normal file
63
api/app/locales/zh/tenant.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "租户信息获取成功",
|
||||
"get_failed": "租户信息获取失败",
|
||||
"update_success": "租户信息更新成功",
|
||||
"update_failed": "租户信息更新失败"
|
||||
},
|
||||
"create": {
|
||||
"success": "租户创建成功",
|
||||
"failed": "租户创建失败"
|
||||
},
|
||||
"delete": {
|
||||
"success": "租户删除成功",
|
||||
"failed": "租户删除失败"
|
||||
},
|
||||
"status": {
|
||||
"activate_success": "租户启用成功",
|
||||
"activate_failed": "租户启用失败",
|
||||
"deactivate_success": "租户禁用成功",
|
||||
"deactivate_failed": "租户禁用失败"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "租户语言配置获取成功",
|
||||
"get_failed": "租户语言配置获取失败",
|
||||
"update_success": "租户语言配置更新成功",
|
||||
"update_failed": "租户语言配置更新失败",
|
||||
"invalid_language": "不支持的语言代码",
|
||||
"default_not_in_supported": "默认语言必须在支持的语言列表中"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "租户列表获取成功",
|
||||
"get_failed": "租户列表获取失败"
|
||||
},
|
||||
"users": {
|
||||
"list_success": "租户用户列表获取成功",
|
||||
"list_failed": "租户用户列表获取失败",
|
||||
"assign_success": "用户分配到租户成功",
|
||||
"assign_failed": "用户分配到租户失败",
|
||||
"remove_success": "用户从租户移除成功",
|
||||
"remove_failed": "用户从租户移除失败"
|
||||
},
|
||||
"statistics": {
|
||||
"get_success": "租户统计信息获取成功",
|
||||
"get_failed": "租户统计信息获取失败"
|
||||
},
|
||||
"validation": {
|
||||
"name_required": "租户名称不能为空",
|
||||
"name_invalid": "租户名称格式不正确",
|
||||
"name_too_long": "租户名称长度不能超过{max}个字符",
|
||||
"description_too_long": "租户描述长度不能超过{max}个字符",
|
||||
"language_code_invalid": "语言代码格式不正确",
|
||||
"supported_languages_empty": "支持的语言列表不能为空"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "租户不存在",
|
||||
"already_exists": "租户名称已存在",
|
||||
"permission_denied": "没有权限访问此租户",
|
||||
"has_users": "无法删除租户,存在关联的用户",
|
||||
"has_workspaces": "无法删除租户,存在关联的工作空间",
|
||||
"already_active": "租户已处于激活状态",
|
||||
"already_inactive": "租户已处于禁用状态"
|
||||
}
|
||||
}
|
||||
72
api/app/locales/zh/users.json
Normal file
72
api/app/locales/zh/users.json
Normal file
@@ -0,0 +1,72 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "用户信息获取成功",
|
||||
"get_failed": "用户信息获取失败",
|
||||
"update_success": "用户信息更新成功",
|
||||
"update_failed": "用户信息更新失败"
|
||||
},
|
||||
"create": {
|
||||
"success": "用户创建成功",
|
||||
"failed": "用户创建失败",
|
||||
"superuser_success": "超级管理员创建成功",
|
||||
"superuser_failed": "超级管理员创建失败"
|
||||
},
|
||||
"delete": {
|
||||
"success": "用户删除成功",
|
||||
"failed": "用户删除失败",
|
||||
"deactivate_success": "用户停用成功",
|
||||
"deactivate_failed": "用户停用失败"
|
||||
},
|
||||
"activate": {
|
||||
"success": "用户激活成功",
|
||||
"failed": "用户激活失败"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "语言偏好获取成功",
|
||||
"get_failed": "语言偏好获取失败",
|
||||
"update_success": "语言偏好更新成功",
|
||||
"update_failed": "语言偏好更新失败",
|
||||
"invalid_language": "不支持的语言代码",
|
||||
"current": "当前语言偏好"
|
||||
},
|
||||
"email": {
|
||||
"change_success": "邮箱修改成功",
|
||||
"change_failed": "邮箱修改失败",
|
||||
"code_sent": "验证码已发送到您的邮箱,请查收",
|
||||
"code_send_failed": "验证码发送失败",
|
||||
"code_invalid": "验证码无效或已过期",
|
||||
"already_exists": "该邮箱已被使用"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "用户列表获取成功",
|
||||
"get_failed": "用户列表获取失败",
|
||||
"superusers_success": "租户超管列表获取成功",
|
||||
"superusers_failed": "租户超管列表获取失败"
|
||||
},
|
||||
"validation": {
|
||||
"username_required": "用户名不能为空",
|
||||
"username_invalid": "用户名格式不正确",
|
||||
"username_too_long": "用户名长度不能超过{max}个字符",
|
||||
"email_required": "邮箱不能为空",
|
||||
"email_invalid": "邮箱格式不正确",
|
||||
"password_required": "密码不能为空",
|
||||
"password_too_short": "密码长度不能少于{min}个字符",
|
||||
"password_too_long": "密码长度不能超过{max}个字符",
|
||||
"old_password_required": "旧密码不能为空",
|
||||
"new_password_required": "新密码不能为空",
|
||||
"verification_code_required": "验证码不能为空",
|
||||
"verification_code_invalid": "验证码格式不正确"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"permission_denied": "没有权限访问此用户",
|
||||
"cannot_delete_self": "不能删除自己",
|
||||
"cannot_deactivate_self": "不能停用自己",
|
||||
"already_deactivated": "用户已被停用",
|
||||
"already_activated": "用户已处于激活状态",
|
||||
"password_verification_failed": "密码验证失败",
|
||||
"old_password_incorrect": "旧密码不正确",
|
||||
"same_as_old_password": "新密码不能与旧密码相同"
|
||||
}
|
||||
}
|
||||
44
api/app/locales/zh/workspace.json
Normal file
44
api/app/locales/zh/workspace.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"list_retrieved": "工作空间列表获取成功",
|
||||
"created": "工作空间创建成功",
|
||||
"updated": "工作空间更新成功",
|
||||
"deleted": "工作空间删除成功",
|
||||
"switched": "工作空间切换成功",
|
||||
"not_found": "工作空间不存在或无权访问",
|
||||
"already_exists": "工作空间已存在",
|
||||
"permission_denied": "没有权限访问此工作空间",
|
||||
"name_required": "工作空间名称不能为空",
|
||||
"invalid_name": "工作空间名称格式不正确",
|
||||
"members": {
|
||||
"list_retrieved": "工作空间成员列表获取成功",
|
||||
"role_updated": "成员角色更新成功",
|
||||
"deleted": "成员删除成功",
|
||||
"not_found": "成员不存在",
|
||||
"cannot_remove_self": "不能删除自己",
|
||||
"cannot_remove_last_manager": "不能删除最后一个管理员",
|
||||
"already_member": "用户已经是工作空间成员"
|
||||
},
|
||||
"invites": {
|
||||
"created": "邀请创建成功",
|
||||
"list_retrieved": "邀请列表获取成功",
|
||||
"validated": "邀请验证成功",
|
||||
"revoked": "邀请撤销成功",
|
||||
"accepted": "邀请已接受",
|
||||
"not_found": "邀请不存在",
|
||||
"expired": "邀请已过期",
|
||||
"already_used": "邀请已被使用",
|
||||
"invalid_token": "无效的邀请令牌",
|
||||
"email_required": "邮箱地址不能为空",
|
||||
"invalid_email": "邮箱地址格式不正确"
|
||||
},
|
||||
"storage": {
|
||||
"type_retrieved": "存储类型获取成功",
|
||||
"type_updated": "存储类型更新成功",
|
||||
"invalid_type": "无效的存储类型"
|
||||
},
|
||||
"models": {
|
||||
"config_retrieved": "模型配置获取成功",
|
||||
"config_updated": "模型配置更新成功",
|
||||
"invalid_config": "无效的模型配置"
|
||||
}
|
||||
}
|
||||
196
api/app/main.py
196
api/app/main.py
@@ -92,6 +92,10 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add i18n language detection middleware
|
||||
from app.i18n.middleware import LanguageMiddleware
|
||||
app.add_middleware(LanguageMiddleware)
|
||||
|
||||
logger.info("FastAPI应用程序启动")
|
||||
|
||||
|
||||
@@ -129,6 +133,11 @@ from app.core.exceptions import (
|
||||
from app.core.sensitive_filter import SensitiveDataFilter
|
||||
import traceback
|
||||
|
||||
# Import i18n exception support
|
||||
from app.i18n.exceptions import I18nException
|
||||
from app.i18n.service import get_translation_service
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
|
||||
# 处理验证异常
|
||||
@app.exception_handler(ValidationException)
|
||||
@@ -156,6 +165,131 @@ async def validation_exception_handler(request: Request, exc: ValidationExceptio
|
||||
)
|
||||
|
||||
|
||||
# 处理 i18n 异常(国际化异常)
|
||||
@app.exception_handler(I18nException)
|
||||
async def i18n_exception_handler(request: Request, exc: I18nException):
|
||||
"""
|
||||
处理国际化异常
|
||||
|
||||
I18nException 已经自动翻译了错误消息,直接返回即可
|
||||
"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取异常详情(已经包含翻译后的消息)
|
||||
detail = exc.detail
|
||||
|
||||
# 过滤敏感信息
|
||||
if isinstance(detail, dict):
|
||||
filtered_message = SensitiveDataFilter.filter_string(detail.get("message", ""))
|
||||
filtered_detail = {
|
||||
**detail,
|
||||
"message": filtered_message
|
||||
}
|
||||
else:
|
||||
filtered_detail = SensitiveDataFilter.filter_string(str(detail))
|
||||
|
||||
logger.warning(
|
||||
f"I18n exception: {exc.error_key}",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"error_code": exc.error_code,
|
||||
"error_key": exc.error_key,
|
||||
"language": language,
|
||||
"status_code": exc.status_code,
|
||||
"params": exc.params
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"success": False,
|
||||
**filtered_detail
|
||||
},
|
||||
headers=exc.headers
|
||||
)
|
||||
|
||||
|
||||
# 处理 Pydantic 验证错误(国际化支持)
|
||||
@app.exception_handler(PydanticValidationError)
|
||||
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
|
||||
"""
|
||||
处理 Pydantic 验证错误,支持国际化
|
||||
"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取翻译服务
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# 翻译验证错误消息
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = ".".join(str(loc) for loc in error["loc"])
|
||||
error_type = error["type"]
|
||||
|
||||
# 尝试翻译错误消息
|
||||
if error_type == "value_error.missing":
|
||||
message = translation_service.translate(
|
||||
"errors.validation.missing_field",
|
||||
language,
|
||||
field=field
|
||||
)
|
||||
elif error_type == "value_error.any_str.max_length":
|
||||
message = translation_service.translate(
|
||||
"errors.validation.field_too_long",
|
||||
language,
|
||||
field=field
|
||||
)
|
||||
elif error_type == "value_error.any_str.min_length":
|
||||
message = translation_service.translate(
|
||||
"errors.validation.field_too_short",
|
||||
language,
|
||||
field=field
|
||||
)
|
||||
else:
|
||||
# 使用通用验证错误消息
|
||||
message = translation_service.translate(
|
||||
"errors.validation.invalid_field",
|
||||
language,
|
||||
field=field
|
||||
)
|
||||
|
||||
errors.append({
|
||||
"field": field,
|
||||
"message": message,
|
||||
"type": error_type
|
||||
})
|
||||
|
||||
# 翻译主错误消息
|
||||
main_message = translation_service.translate(
|
||||
"errors.common.validation_failed",
|
||||
language
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"Pydantic validation error: {len(errors)} errors",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"language": language,
|
||||
"errors": errors
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={
|
||||
"success": False,
|
||||
"error_code": "VALIDATION_FAILED",
|
||||
"message": main_message,
|
||||
"errors": errors
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 处理资源不存在异常
|
||||
@app.exception_handler(ResourceNotFoundException)
|
||||
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
|
||||
@@ -354,31 +488,66 @@ async def business_exception_handler(request: Request, exc: BusinessException):
|
||||
)
|
||||
|
||||
|
||||
# 统一异常处理:将HTTPException转换为统一响应结构
|
||||
# 统一异常处理:将HTTPException转换为统一响应结构(支持国际化)
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""处理HTTP异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))
|
||||
|
||||
"""处理HTTP异常,支持国际化"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取翻译服务
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# 尝试翻译标准HTTP错误
|
||||
error_key_map = {
|
||||
400: "errors.common.bad_request",
|
||||
401: "errors.common.unauthorized",
|
||||
403: "errors.common.forbidden",
|
||||
404: "errors.common.not_found",
|
||||
405: "errors.common.method_not_allowed",
|
||||
409: "errors.common.conflict",
|
||||
422: "errors.common.validation_failed",
|
||||
429: "errors.common.too_many_requests",
|
||||
500: "errors.common.internal_error",
|
||||
503: "errors.common.service_unavailable",
|
||||
}
|
||||
|
||||
# 如果有对应的翻译键,使用翻译
|
||||
if exc.status_code in error_key_map:
|
||||
translated_message = translation_service.translate(
|
||||
error_key_map[exc.status_code],
|
||||
language
|
||||
)
|
||||
else:
|
||||
# 否则过滤原始消息
|
||||
translated_message = SensitiveDataFilter.filter_string(str(exc.detail))
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {filtered_detail}",
|
||||
f"HTTP exception: {translated_message}",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"status_code": exc.status_code
|
||||
"status_code": exc.status_code,
|
||||
"language": language
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail)
|
||||
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
|
||||
)
|
||||
|
||||
|
||||
# 捕获未处理的异常,返回统一错误结构
|
||||
# 捕获未处理的异常,返回统一错误结构(支持国际化)
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
"""处理未捕获的异常"""
|
||||
"""处理未捕获的异常,支持国际化"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取翻译服务
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# 记录完整的堆栈跟踪(日志过滤器会自动过滤敏感信息)
|
||||
logger.error(
|
||||
f"Unhandled exception: {exc}",
|
||||
@@ -386,6 +555,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"exception_type": type(exc).__name__,
|
||||
"language": language,
|
||||
"traceback": traceback.format_exc()
|
||||
},
|
||||
exc_info=True
|
||||
@@ -394,7 +564,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
# 生产环境隐藏详细错误信息
|
||||
environment = os.getenv("ENVIRONMENT", "development")
|
||||
if environment == "production":
|
||||
message = "服务器内部错误,请稍后重试"
|
||||
# 使用翻译的通用错误消息
|
||||
message = translation_service.translate(
|
||||
"errors.common.internal_error",
|
||||
language
|
||||
)
|
||||
else:
|
||||
# 开发环境也要过滤敏感信息
|
||||
message = SensitiveDataFilter.filter_string(str(exc))
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.db import Base
|
||||
from app.schemas import FileType
|
||||
|
||||
|
||||
class PerceptualType(IntEnum):
|
||||
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
|
||||
TEXT = 3
|
||||
CONVERSATION = 4
|
||||
|
||||
@staticmethod
|
||||
def trans_from_file_type(file_type: FileType | str):
|
||||
type_map = {
|
||||
FileType.IMAGE: PerceptualType.VISION,
|
||||
FileType.AUDIO: PerceptualType.AUDIO,
|
||||
FileType.VIDEO: PerceptualType.VISION,
|
||||
FileType.DOCUMENT: PerceptualType.TEXT
|
||||
}
|
||||
return type_map.get(file_type, PerceptualType.TEXT)
|
||||
|
||||
|
||||
class FileStorageService(IntEnum):
|
||||
LOCAL = 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import Column, String, DateTime, Boolean, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, ARRAY
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
@@ -20,6 +20,10 @@ class Tenants(Base):
|
||||
external_id = Column(String(100), nullable=True, index=True) # 外部企业ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 国际化语言配置字段
|
||||
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
|
||||
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
|
||||
|
||||
# Relationship to users - one tenant has many users
|
||||
users = relationship("User", back_populates="tenant")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
@@ -22,6 +22,9 @@ class User(Base):
|
||||
external_id = Column(String(100), nullable=True) # 外部用户ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 用户语言偏好
|
||||
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
||||
|
||||
current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True) # 当前工作空间ID,可为空
|
||||
|
||||
# Foreign key to tenant - each user belongs to exactly one tenant
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
|
||||
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_url(
|
||||
self,
|
||||
file_url: str
|
||||
) -> list[MemoryPerceptualModel]:
|
||||
try:
|
||||
stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
except Exception:
|
||||
db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}")
|
||||
raise
|
||||
|
||||
def get_by_type(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
|
||||
73
api/app/schemas/i18n_schema.py
Normal file
73
api/app/schemas/i18n_schema.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
I18n Management API Schemas
|
||||
|
||||
This module defines Pydantic schemas for i18n management APIs.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management Schemas
|
||||
# ============================================================================
|
||||
|
||||
class LanguageInfo(BaseModel):
|
||||
"""Language information"""
|
||||
code: str = Field(..., description="Language code (e.g., 'zh', 'en')")
|
||||
name: str = Field(..., description="Language name (e.g., 'Chinese', 'English')")
|
||||
native_name: str = Field(..., description="Native language name (e.g., '中文', 'English')")
|
||||
is_enabled: bool = Field(..., description="Whether the language is enabled")
|
||||
is_default: bool = Field(..., description="Whether this is the default language")
|
||||
|
||||
|
||||
class LanguageListResponse(BaseModel):
|
||||
"""Response for language list"""
|
||||
languages: List[LanguageInfo] = Field(..., description="List of available languages")
|
||||
|
||||
|
||||
class LanguageCreateRequest(BaseModel):
|
||||
"""Request to add a new language"""
|
||||
code: str = Field(..., description="Language code (e.g., 'ja', 'ko')", min_length=2, max_length=10)
|
||||
name: str = Field(..., description="Language name", min_length=1, max_length=100)
|
||||
native_name: str = Field(..., description="Native language name", min_length=1, max_length=100)
|
||||
is_enabled: bool = Field(default=True, description="Whether to enable the language")
|
||||
|
||||
|
||||
class LanguageUpdateRequest(BaseModel):
|
||||
"""Request to update language configuration"""
|
||||
is_enabled: Optional[bool] = Field(None, description="Whether the language is enabled")
|
||||
is_default: Optional[bool] = Field(None, description="Whether this is the default language")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management Schemas
|
||||
# ============================================================================
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""Response for translation data"""
|
||||
translations: Dict[str, Dict[str, Any]] = Field(
|
||||
...,
|
||||
description="Translations organized by locale and namespace"
|
||||
)
|
||||
|
||||
|
||||
class TranslationUpdateRequest(BaseModel):
|
||||
"""Request to update a translation"""
|
||||
value: str = Field(..., description="New translation value", min_length=1)
|
||||
description: Optional[str] = Field(None, description="Optional description of the translation")
|
||||
|
||||
|
||||
class MissingTranslationsResponse(BaseModel):
|
||||
"""Response for missing translations"""
|
||||
missing_translations: Dict[str, List[str]] = Field(
|
||||
...,
|
||||
description="Missing translation keys organized by locale"
|
||||
)
|
||||
|
||||
|
||||
class ReloadResponse(BaseModel):
|
||||
"""Response for translation reload"""
|
||||
success: bool = Field(..., description="Whether the reload was successful")
|
||||
reloaded_locales: List[str] = Field(..., description="List of reloaded locales")
|
||||
total_locales: int = Field(..., description="Total number of available locales")
|
||||
@@ -25,5 +25,6 @@ class AgentMemory_Long_Term(ABC):
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
TIME_SCOPE=5
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
|
||||
@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
|
||||
is_official: Optional[bool] = Field(None, description="是否官方模型")
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""模型信息Schema"""
|
||||
model_name: str = Field(..., description="模型名称")
|
||||
provider: str = Field(..., description="模型提供商")
|
||||
api_key: str = Field(..., description="API密钥")
|
||||
api_base: str = Field(..., description="API基础URL")
|
||||
is_omni: bool = Field(default=False, description="是否为omni模型")
|
||||
model_type: ModelType = Field(..., description="模型类型")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ class TenantBase(BaseModel):
|
||||
name: str = Field(..., description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
default_language: Optional[str] = Field('zh', description="租户默认语言", max_length=10)
|
||||
supported_languages: Optional[List[str]] = Field(['zh', 'en'], description="租户支持的语言列表")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
@@ -18,6 +20,26 @@ class TenantBase(BaseModel):
|
||||
if not v or not v.strip():
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip()
|
||||
|
||||
@field_validator('default_language')
|
||||
@classmethod
|
||||
def validate_default_language(cls, v):
|
||||
if v:
|
||||
# Validate language code format (2-letter code, optionally with region)
|
||||
import re
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
|
||||
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
@field_validator('supported_languages')
|
||||
@classmethod
|
||||
def validate_supported_languages(cls, v):
|
||||
if v:
|
||||
import re
|
||||
for lang in v:
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
|
||||
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
|
||||
class TenantCreate(TenantBase):
|
||||
@@ -30,6 +52,8 @@ class TenantUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
default_language: Optional[str] = Field(None, description="租户默认语言", max_length=10)
|
||||
supported_languages: Optional[List[str]] = Field(None, description="租户支持的语言列表")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
@@ -37,6 +61,25 @@ class TenantUpdate(BaseModel):
|
||||
if v is not None and (not v or not v.strip()):
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator('default_language')
|
||||
@classmethod
|
||||
def validate_default_language(cls, v):
|
||||
if v:
|
||||
import re
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
|
||||
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
@field_validator('supported_languages')
|
||||
@classmethod
|
||||
def validate_supported_languages(cls, v):
|
||||
if v:
|
||||
import re
|
||||
for lang in v:
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
|
||||
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
|
||||
class Tenant(TenantBase):
|
||||
@@ -62,4 +105,29 @@ class TenantList(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
size: int
|
||||
pages: int
|
||||
pages: int
|
||||
|
||||
|
||||
class TenantLanguageConfig(BaseModel):
|
||||
"""租户语言配置Schema"""
|
||||
default_language: str = Field(..., description="租户默认语言", max_length=10)
|
||||
supported_languages: List[str] = Field(..., description="租户支持的语言列表")
|
||||
|
||||
@field_validator('default_language')
|
||||
@classmethod
|
||||
def validate_default_language(cls, v):
|
||||
import re
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
|
||||
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
@field_validator('supported_languages')
|
||||
@classmethod
|
||||
def validate_supported_languages(cls, v):
|
||||
if not v:
|
||||
raise ValidationException('支持的语言列表不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
import re
|
||||
for lang in v:
|
||||
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
|
||||
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
|
||||
return v
|
||||
|
||||
@@ -58,6 +58,16 @@ class VerifyPasswordRequest(BaseModel):
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class LanguagePreferenceRequest(BaseModel):
|
||||
"""语言偏好设置请求"""
|
||||
language: str = Field(..., min_length=2, max_length=10, description="语言代码,如 'zh', 'en'")
|
||||
|
||||
|
||||
class LanguagePreferenceResponse(BaseModel):
|
||||
"""语言偏好响应"""
|
||||
language: str = Field(..., description="当前语言偏好")
|
||||
|
||||
|
||||
class ChangePasswordResponse(BaseModel):
|
||||
"""修改密码响应"""
|
||||
message: str
|
||||
@@ -74,6 +84,7 @@ class User(UserBase):
|
||||
current_workspace_id: Optional[uuid.UUID] = None
|
||||
current_workspace_name: Optional[str] = None
|
||||
role: Optional[WorkspaceRole] = None
|
||||
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
||||
|
||||
# 将 datetime 转换为毫秒时间戳
|
||||
@validator("created_at", pre=True)
|
||||
|
||||
@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||
AgentRunService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -126,8 +122,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -266,8 +271,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
|
||||
@@ -75,7 +75,7 @@ class AudioTranscriptionService:
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response = await client.get(audio_url, follow_redirects=True)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.service import t
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
logger.warning(f"用户不存在: {email}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
logger.warning(f"用户未激活: {email}")
|
||||
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
logger.warning(f"密码错误: {email}")
|
||||
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
|
||||
|
||||
logger.info(f"用户认证成功: {email}")
|
||||
return user
|
||||
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
|
||||
Raises:
|
||||
BusinessException: token 无效
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
|
||||
return {
|
||||
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
|
||||
"share_token": payload["share_token"]
|
||||
}
|
||||
except jwt.InvalidTokenError:
|
||||
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
|
||||
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)
|
||||
@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -501,9 +502,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -704,9 +714,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -841,7 +860,8 @@ class AgentRunService:
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
"is_omni": api_key.is_omni,
|
||||
"capability": api_key.capability
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
|
||||
@@ -274,7 +274,7 @@ class MemoryAgentService:
|
||||
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
PerceptualMemoryItem,
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
|
||||
"keywords": content.keywords,
|
||||
"topic": content.topic,
|
||||
"domain": content.domain,
|
||||
"created_time": int(memory.created_time.timestamp()*1000),
|
||||
"created_time": int(memory.created_time.timestamp() * 1000),
|
||||
**detail
|
||||
}
|
||||
|
||||
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
exc_info=True)
|
||||
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
BizCode.DB_ERROR)
|
||||
|
||||
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
|
||||
for memory in memories:
|
||||
meta_data = memory.meta_data or {}
|
||||
content = meta_data.get("content", {})
|
||||
|
||||
|
||||
# 安全地提取 content 字段,提供默认值
|
||||
if content:
|
||||
content_obj = Content(**content)
|
||||
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
|
||||
topic = "Unknown"
|
||||
domain = "Unknown"
|
||||
keywords = []
|
||||
|
||||
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
|
||||
topic=topic,
|
||||
domain=domain,
|
||||
keywords=keywords,
|
||||
created_time=int(memory.created_time.timestamp()*1000),
|
||||
created_time=int(memory.created_time.timestamp() * 1000),
|
||||
storage_service=FileStorageService(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
"keywords": content.get("keywords", []),
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene")
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count"),
|
||||
"title": content.get("title"),
|
||||
"first_line": content.get("first_line")
|
||||
}
|
||||
else:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count")
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary'),
|
||||
meta_data={
|
||||
"content": file_content,
|
||||
"modalities": file_modalities
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
@@ -23,9 +24,12 @@ from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import ModelApiKey
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -39,6 +43,7 @@ DOC_MIME = [
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
# 下载图片
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
"""
|
||||
Service for handling multimodal file processing.
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
Attributes:
|
||||
db (Session): Database session.
|
||||
model_api_key (str): API key for the model provider.
|
||||
provider (str): Name of the model provider.
|
||||
is_omni (bool): Indicates whether the model supports full multimodal capability.
|
||||
capability (list): Capability configuration of the model.
|
||||
audio_api_key (str | None): API key used for audio transcription.
|
||||
enable_audio_transcription (bool): Whether audio transcription is enabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_config: ModelInfo | None = None,
|
||||
audio_api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Initialize the multimodal service.
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
db (Session): Database session.
|
||||
api_config (ModelApiKey | None): Model API configuration.
|
||||
audio_api_key (str | None): API key for audio transcription.
|
||||
enable_audio_transcription (bool): Enable audio transcription.
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.api_config = api_config
|
||||
if self.api_config is not None:
|
||||
self.model_api_key = api_config.api_key
|
||||
self.provider = api_config.provider.lower()
|
||||
self.is_omni = api_config.is_omni
|
||||
self.capability = api_config.capability
|
||||
self.audio_api_key = audio_api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -319,6 +346,8 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -333,19 +362,25 @@ class MultimodalService:
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
if not file.url:
|
||||
file.url = await self.get_file_url(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -355,7 +390,8 @@ class MultimodalService:
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
@@ -366,6 +402,17 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -387,43 +434,6 @@ class MultimodalService:
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_encode_image(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
Args:
|
||||
url: 图片 URL
|
||||
|
||||
Returns:
|
||||
tuple: (base64_data, media_type)
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
# 从 URL 推断
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -436,7 +446,6 @@ class MultimodalService:
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
@@ -471,12 +480,12 @@ class MultimodalService:
|
||||
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
if self.enable_audio_transcription and self.audio_api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
@@ -557,7 +566,7 @@ class MultimodalService:
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response = await client.get(file.url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
|
||||
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
@@ -0,0 +1,53 @@
|
||||
{% raw %}You are a professional information extraction system.
|
||||
|
||||
Your task is to analyze the provided document content and generate structured metadata.
|
||||
|
||||
Extract the following fields:
|
||||
|
||||
* **summary**: A concise summary of the document in 2–4 sentences.
|
||||
* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
|
||||
* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
|
||||
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
|
||||
|
||||
STRICT RULES:
|
||||
|
||||
1. Output MUST be valid JSON.
|
||||
2. Do NOT output markdown.
|
||||
3. Do NOT output explanations.
|
||||
4. Do NOT output any text before or after the JSON.
|
||||
5. The JSON MUST contain EXACTLY these four keys:
|
||||
* summary
|
||||
* keywords
|
||||
* topic
|
||||
* domain{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
|
||||
{% if file_type == 'audio' %} * speaker_count {% endif %}
|
||||
{% if file_type == 'document' %} * section_count
|
||||
* title
|
||||
* first_line
|
||||
{% endif %}
|
||||
{% raw %}
|
||||
6. `keywords` MUST be a JSON array of strings.
|
||||
7. If the document content is insufficient, infer the best possible answer based on context.
|
||||
8. Ensure the JSON is syntactically correct.
|
||||
{% endraw %}
|
||||
9. Output using the language {{ language }}
|
||||
{% raw %}
|
||||
Required JSON format:
|
||||
|
||||
{
|
||||
"summary": "string",
|
||||
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
|
||||
"topic": "string",
|
||||
"domain": "string",
|
||||
{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
|
||||
{% if file_type == 'document' %} "section_count": integer
|
||||
"title": "string",
|
||||
"first_line": "string"
|
||||
{% endif %}
|
||||
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
|
||||
{% raw %}
|
||||
}
|
||||
|
||||
Now analyze the following document and return the JSON result.{% endraw %}
|
||||
@@ -217,4 +217,55 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active
|
||||
)
|
||||
)
|
||||
|
||||
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
|
||||
"""获取租户语言配置"""
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
return {
|
||||
"default_language": tenant.default_language,
|
||||
"supported_languages": tenant.supported_languages
|
||||
}
|
||||
|
||||
def update_tenant_language_config(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
default_language: str,
|
||||
supported_languages: list
|
||||
) -> Optional[dict]:
|
||||
"""更新租户语言配置"""
|
||||
# 检查租户是否存在
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
# 验证默认语言在支持的语言列表中
|
||||
if default_language not in supported_languages:
|
||||
raise BusinessException(
|
||||
"默认语言必须在支持的语言列表中",
|
||||
code=BizCode.VALIDATION_FAILED
|
||||
)
|
||||
|
||||
try:
|
||||
# 更新语言配置
|
||||
tenant.default_language = default_language
|
||||
tenant.supported_languages = supported_languages
|
||||
self.db.commit()
|
||||
self.db.refresh(tenant)
|
||||
|
||||
business_logger.info(
|
||||
f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
|
||||
f"默认语言: {default_language}, 支持语言: {supported_languages}"
|
||||
)
|
||||
|
||||
return {
|
||||
"default_language": tenant.default_language,
|
||||
"supported_languages": tenant.supported_languages
|
||||
}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
business_logger.error(f"更新租户语言配置失败: {str(e)}")
|
||||
raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
|
||||
|
||||
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
|
||||
"""普通用户修改自己的密码"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
|
||||
|
||||
# 检查权限:只能修改自己的密码
|
||||
if current_user.id != user_id:
|
||||
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
|
||||
raise PermissionDeniedException("You can only change your own password")
|
||||
raise PermissionDeniedException(t("auth.password.change_failed"))
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, db_user.hashed_password):
|
||||
business_logger.warning(f"用户旧密码验证失败: {user_id}")
|
||||
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 更新密码
|
||||
db_user.hashed_password = get_password_hash(new_password)
|
||||
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
|
||||
except Exception as e:
|
||||
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
|
||||
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
Returns:
|
||||
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
|
||||
|
||||
# 检查权限:只有超级管理员可以修改他人密码
|
||||
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员可以修改他人密码"
|
||||
error_message=t("auth.password.change_failed")
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
|
||||
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
|
||||
if not target_user:
|
||||
business_logger.warning(f"目标用户不存在: {target_user_id}")
|
||||
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查租户权限:超管只能修改同租户用户的密码
|
||||
if current_user.tenant_id != target_user.tenant_id:
|
||||
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
|
||||
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 如果没有提供新密码,则生成随机密码
|
||||
actual_password = new_password if new_password else generate_random_password()
|
||||
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
except Exception as e:
|
||||
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def generate_random_password(length: int = 12) -> str:
|
||||
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
|
||||
#
|
||||
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
# return db_user
|
||||
|
||||
|
||||
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
|
||||
"""获取用户语言偏好"""
|
||||
business_logger.info(f"获取用户语言偏好: user_id={user_id}")
|
||||
|
||||
# 权限检查:只能获取自己的语言偏好
|
||||
if current_user.id != user_id:
|
||||
raise PermissionDeniedException("只能获取自己的语言偏好")
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
language = db_user.preferred_language or "zh"
|
||||
business_logger.info(f"用户语言偏好: {db_user.username}, language={language}")
|
||||
return language
|
||||
|
||||
|
||||
def update_user_language_preference(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
language: str,
|
||||
current_user: User
|
||||
) -> User:
|
||||
"""更新用户语言偏好"""
|
||||
business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")
|
||||
|
||||
# 权限检查:只能修改自己的语言偏好
|
||||
if current_user.id != user_id:
|
||||
raise PermissionDeniedException("只能修改自己的语言偏好")
|
||||
|
||||
# 验证语言代码是否支持
|
||||
from app.core.config import settings
|
||||
if language not in settings.I18N_SUPPORTED_LANGUAGES:
|
||||
raise BusinessException(
|
||||
f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}",
|
||||
code=BizCode.VALIDATION_FAILED
|
||||
)
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 更新语言偏好
|
||||
db_user.preferred_language = language
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
business_logger.info(f"用户语言偏好更新成功: {db_user.username}, language={language}")
|
||||
return db_user
|
||||
|
||||
392
api/app/tasks.py
392
api/app/tasks.py
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import Document, File, Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
from app.utils.redis_lock import RedisLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
_sync_redis_pool: redis.ConnectionPool = None
|
||||
_sync_redis_pool: redis.ConnectionPool | None = None
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
|
||||
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||
global _sync_redis_pool
|
||||
if _sync_redis_pool is None:
|
||||
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
return None
|
||||
return _sync_redis_pool
|
||||
|
||||
|
||||
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
"""获取同步 Redis 客户端(使用连接池)
|
||||
|
||||
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
pool = _get_or_create_redis_pool()
|
||||
if pool is None:
|
||||
return None
|
||||
|
||||
|
||||
client = redis.StrictRedis(connection_pool=pool)
|
||||
# 验证连接可用性
|
||||
client.ping()
|
||||
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
def set_asyncio_event_loop():
|
||||
"""Set the asyncio event loop for the current thread."""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
vector_size = len(vts[0])
|
||||
init_graphrag(task, vector_size)
|
||||
|
||||
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
|
||||
chat_model, embedding_model, callback, with_resolution: bool = True,
|
||||
with_community: bool = True, ) -> dict:
|
||||
async def _run(
|
||||
row: dict,
|
||||
document_ids: list[str],
|
||||
language: str,
|
||||
parser_config: dict,
|
||||
vector_service,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True
|
||||
) -> dict:
|
||||
await trio.sleep(5) # Delay for 10 seconds
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
result = await run_graphrag_for_kb(
|
||||
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
async def _run() -> dict:
|
||||
with get_db_context() as db:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||
storage_type, user_rag_memory_id)
|
||||
return await service.read_memory(
|
||||
end_user_id,
|
||||
message,
|
||||
history,
|
||||
search_switch,
|
||||
actual_config_id, db,
|
||||
storage_type, user_rag_memory_id
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str,
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh") -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}")
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||
f"storage_type={storage_type}, language={language}")
|
||||
start_time = time.time()
|
||||
|
||||
# Convert config_id to UUID
|
||||
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(config_id, db)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
print(actual_config_id)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
logger.error(
|
||||
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
||||
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
async def _run() -> str:
|
||||
with get_db_context() as db:
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||
user_rag_memory_id, language)
|
||||
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
return result
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
}
|
||||
|
||||
|
||||
def reflection_engine() -> None:
|
||||
"""Empty function placeholder for timed background reflection.
|
||||
|
||||
Intentionally left blank; replace with real reflection logic later.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
||||
def reflection_timer_task() -> None:
|
||||
"""Periodic Celery task that invokes reflection_engine.
|
||||
|
||||
Raises an exception on failure.
|
||||
"""
|
||||
reflection_engine()
|
||||
|
||||
|
||||
# unused task
|
||||
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
# def check_read_service_task() -> Dict[str, str]:
|
||||
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||
|
||||
service = UserMemoryService()
|
||||
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
if not data['memory_configs']:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
|
||||
user['app_id']):
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
# 运行遗忘周期
|
||||
# FIXME: MemeoryForgetService
|
||||
report = await forget_service.trigger_forgetting(
|
||||
db=db,
|
||||
end_user_id=None, # 处理所有组
|
||||
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"遗忘周期定时任务完成: "
|
||||
f"融合 {report['merged_count']} 对节点, "
|
||||
f"失败 {report['failed_count']} 对, "
|
||||
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILED",
|
||||
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
|
||||
total_users = 0
|
||||
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
for end_user_id in refresh_iter:
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
|
||||
|
||||
# 记录用户处理结果
|
||||
user_result = {
|
||||
"end_user_id": end_user_id,
|
||||
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
self: task object
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_perceptual_memory",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_api_config: dict,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""
|
||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
||||
|
||||
This task generates or updates the user's perceptual memory
|
||||
in the backend databases. It is intended to be executed asynchronously
|
||||
via Celery.
|
||||
|
||||
Args:
|
||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
||||
model_api_config (ModelInfo): API configuration for the model
|
||||
used to generate perceptual memory.
|
||||
file_type (str): The file type
|
||||
file_url (url): The url of file
|
||||
file_message (dict): The file message containing details about the file
|
||||
to be processed.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
||||
set_asyncio_event_loop()
|
||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
||||
model_info = ModelInfo(**model_api_config)
|
||||
with get_db_context() as db:
|
||||
memory_perceptual_service = MemoryPerceptualService(db)
|
||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
||||
end_user_id,
|
||||
model_info,
|
||||
file_type,
|
||||
file_url,
|
||||
file_message,
|
||||
))
|
||||
|
||||
61
api/app/utils/redis_lock.py
Normal file
61
api/app/utils/redis_lock.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import redis
|
||||
import uuid
|
||||
import time
|
||||
|
||||
UNLOCK_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("del", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
|
||||
|
||||
class RedisLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
redis_client: redis.StrictRedis,
|
||||
expire: int = 60,
|
||||
retry_interval: float = 0.1,
|
||||
timeout: float = 30
|
||||
|
||||
):
|
||||
self.key = key
|
||||
self.expire = expire
|
||||
self.value = str(uuid.uuid4())
|
||||
self._locked = False
|
||||
self.retry_interval = retry_interval
|
||||
self.timeout = timeout
|
||||
self.redis_client = redis_client
|
||||
|
||||
def acquire(self) -> bool:
|
||||
start = time.time()
|
||||
while True:
|
||||
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
|
||||
if ok:
|
||||
self._locked = True
|
||||
return True
|
||||
if time.time() - start >= self.timeout:
|
||||
return False
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def release(self):
|
||||
if not self._locked:
|
||||
return
|
||||
self.redis_client.eval(
|
||||
UNLOCK_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value
|
||||
)
|
||||
self._locked = False
|
||||
|
||||
def __enter__(self):
|
||||
ok = self.acquire()
|
||||
if not ok:
|
||||
raise RuntimeError(f"Get redis lock timeout: {self.key}")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.release()
|
||||
38
api/migrations/versions/01587a13522f_202603131028.py
Normal file
38
api/migrations/versions/01587a13522f_202603131028.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""202603131028
|
||||
|
||||
Revision ID: 01587a13522f
|
||||
Revises: fb834419b18f
|
||||
Create Date: 2026-03-13 10:28:43.601370
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '01587a13522f'
|
||||
down_revision: Union[str, None] = 'fb834419b18f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('tenants', sa.Column('default_language', sa.String(length=10), server_default='zh', nullable=False))
|
||||
op.add_column('tenants', sa.Column('supported_languages', postgresql.ARRAY(sa.String(length=10)), server_default=sa.text("'{zh,en}'"), nullable=False))
|
||||
op.create_index(op.f('ix_tenants_default_language'), 'tenants', ['default_language'], unique=False)
|
||||
op.add_column('users', sa.Column('preferred_language', sa.String(length=10), server_default=sa.text("'zh'"), nullable=False))
|
||||
op.create_index(op.f('ix_users_preferred_language'), 'users', ['preferred_language'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_users_preferred_language'), table_name='users')
|
||||
op.drop_column('users', 'preferred_language')
|
||||
op.drop_index(op.f('ix_tenants_default_language'), table_name='tenants')
|
||||
op.drop_column('tenants', 'supported_languages')
|
||||
op.drop_column('tenants', 'default_language')
|
||||
# ### end Alembic commands ###
|
||||
30
api/migrations/versions/ea31b4e347d8_202603131452.py
Normal file
30
api/migrations/versions/ea31b4e347d8_202603131452.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""202603131452
|
||||
|
||||
Revision ID: ea31b4e347d8
|
||||
Revises: 01587a13522f
|
||||
Create Date: 2026-03-13 14:53:20.587580
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ea31b4e347d8'
|
||||
down_revision: Union[str, None] = '01587a13522f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('app_shares', sa.Column('permission', sa.String(), nullable=False, comment='权限模式: readonly | editable'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('app_shares', 'permission')
|
||||
# ### end Alembic commands ###
|
||||
@@ -36,6 +36,7 @@
|
||||
"codemirror": "^6.0.2",
|
||||
"copy-to-clipboard": "^3.3.3",
|
||||
"crypto-js": "^4.2.0",
|
||||
"d3": "^7.9.0",
|
||||
"dayjs": "^1.11.18",
|
||||
"echarts": "^5.6.0",
|
||||
"echarts-for-react": "^3.0.2",
|
||||
@@ -67,6 +68,7 @@
|
||||
"@tailwindcss/vite": "^4.1.14",
|
||||
"@types/codemirror": "^5.60.17",
|
||||
"@types/crypto-js": "^4.2.2",
|
||||
"@types/d3": "^7.4.3",
|
||||
"@types/js-yaml": "^4.0.9",
|
||||
"@types/node": "^24.6.0",
|
||||
"@types/react": "^18.2.0",
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-04 10:58:41
|
||||
* @Last Modified time: 2026-03-13 10:48:41
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type { AxiosRequestConfig } from 'axios'
|
||||
import type {
|
||||
MemoryFormData,
|
||||
} from '@/views/MemoryManagement/types'
|
||||
@@ -94,8 +95,12 @@ export const updatedEndUserProfile = (values: EndUser) => {
|
||||
return request.post(`/memory-storage/updated_end_user/profile`, values)
|
||||
}
|
||||
// User Memory - Relationship network
|
||||
export const getMemorySearchEdges = (end_user_id: string) => {
|
||||
return request.get(`/memory-storage/analytics/graph_data`, { end_user_id })
|
||||
export const getMemorySearchEdges = (end_user_id: string, config?: AxiosRequestConfig) => {
|
||||
return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }, config)
|
||||
}
|
||||
// User Memory - Community graph
|
||||
export const getMemoryCommunityGraph = (end_user_id: string, config?: AxiosRequestConfig) => {
|
||||
return request.get(`/memory-storage/analytics/community_graph`, { end_user_id }, config)
|
||||
}
|
||||
// User Memory - User interest distribution
|
||||
export const getInterestDistributionByUser = (end_user_id: string) => {
|
||||
|
||||
67
web/src/components/D3Graph/CommunityGraph.tsx
Normal file
67
web/src/components/D3Graph/CommunityGraph.tsx
Normal file
@@ -0,0 +1,67 @@
|
||||
import React, { useState, useRef, useMemo, useEffect, type FC } from 'react'
|
||||
import Empty from '@/components/Empty'
|
||||
import { GRAPH_COLORS, initCommunityGraph } from './utils'
|
||||
import { useD3Graph } from './hooks'
|
||||
import type { CommunityD3Node, D3Link, CommunityGraphProps } from './types'
|
||||
|
||||
// ─── Component ────────────────────────────────────────────────────────────────
|
||||
// Renders a D3-powered community graph with optional tooltip and legend.
|
||||
|
||||
const CommunityGraph: FC<CommunityGraphProps> = ({
|
||||
data,
|
||||
empty: emptyProp,
|
||||
colors = GRAPH_COLORS,
|
||||
renderTooltip,
|
||||
showLegend = true,
|
||||
onCommunityClick,
|
||||
onNodeClick,
|
||||
defaultZoom = 1,
|
||||
}) => {
|
||||
// Tooltip position and hovered node state
|
||||
const [tooltip, setTooltip] = useState<{ x: number; y: number; node: CommunityD3Node } | null>(null)
|
||||
|
||||
// Keep callback refs stable to avoid re-initializing the graph on every render
|
||||
const onCommunityClickRef = useRef(onCommunityClick)
|
||||
const onNodeClickRef = useRef(onNodeClick)
|
||||
const renderTooltipRef = useRef(renderTooltip)
|
||||
useEffect(() => { onCommunityClickRef.current = onCommunityClick }, [onCommunityClick])
|
||||
useEffect(() => { onNodeClickRef.current = onNodeClick }, [onNodeClick])
|
||||
useEffect(() => { renderTooltipRef.current = renderTooltip }, [renderTooltip])
|
||||
|
||||
const graphState = useMemo(() => data, [data])
|
||||
// Show empty state when explicitly flagged or when there are no nodes
|
||||
const isEmpty = emptyProp ?? !data?.nodes.length
|
||||
|
||||
// Initialize (or re-initialize) the D3 graph whenever relevant state changes
|
||||
const containerRef = useD3Graph((container) => {
|
||||
if (!graphState) return
|
||||
return initCommunityGraph(
|
||||
container,
|
||||
graphState.nodes,
|
||||
graphState.links as D3Link[],
|
||||
graphState.communityMap,
|
||||
graphState.communityCaption,
|
||||
graphState.communityNodeMap,
|
||||
{ colors, showLegend, defaultZoom, setTooltip: renderTooltip ? setTooltip : () => {}, onCommunityClickRef, onNodeClickRef }
|
||||
)
|
||||
}, [graphState, showLegend, defaultZoom])
|
||||
|
||||
// Resolve tooltip content: use custom renderer if provided, otherwise fall back to DefaultTooltip
|
||||
const tooltipNode = tooltip && renderTooltipRef.current
|
||||
? renderTooltipRef.current(tooltip.node)
|
||||
: null
|
||||
|
||||
if (isEmpty) return <Empty className="rb:h-full" />
|
||||
return (
|
||||
<div className="rb:w-full rb:h-full rb:relative">
|
||||
<div ref={containerRef} className="rb:w-full rb:h-full" />
|
||||
{tooltipNode ? (
|
||||
<div style={{ position: 'absolute', left: tooltip!.x + 14, top: tooltip!.y - 10, pointerEvents: 'none', zIndex: 20 }}>
|
||||
{tooltipNode}
|
||||
</div>
|
||||
) : undefined}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(CommunityGraph)
|
||||
24
web/src/components/D3Graph/hooks.ts
Normal file
24
web/src/components/D3Graph/hooks.ts
Normal file
@@ -0,0 +1,24 @@
|
||||
import { useRef, useEffect } from 'react'
|
||||
import * as d3 from 'd3'
|
||||
|
||||
/**
|
||||
* Generic hook that mounts a D3 graph inside a div container.
|
||||
* Clears any existing SVG before calling initFn, and runs cleanup on unmount or dep change.
|
||||
*/
|
||||
export function useD3Graph<T>(
|
||||
initFn: (container: HTMLDivElement) => (() => void) | void,
|
||||
deps: T[]
|
||||
) {
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
useEffect(() => {
|
||||
const container = containerRef.current
|
||||
if (!container) return
|
||||
d3.select(container).selectAll('svg').remove()
|
||||
const cleanup = initFn(container)
|
||||
return () => {
|
||||
cleanup?.()
|
||||
d3.select(container).selectAll('svg').remove()
|
||||
}
|
||||
}, deps)
|
||||
return containerRef
|
||||
}
|
||||
102
web/src/components/D3Graph/types.ts
Normal file
102
web/src/components/D3Graph/types.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import type { ReactNode, RefObject } from 'react'
|
||||
import type * as d3 from 'd3'
|
||||
|
||||
// ─── Raw input types (mirror of API response, no external dependency) ─────────
|
||||
// These interfaces map 1-to-1 with the graph API response shape.
|
||||
|
||||
export interface RawCommunityNode {
|
||||
id: string
|
||||
label: 'Community'
|
||||
properties: {
|
||||
name: string
|
||||
summary: string
|
||||
member_entity_ids: string[]
|
||||
member_count: number
|
||||
core_entities: string[]
|
||||
community_id: string
|
||||
end_user_id?: string
|
||||
updated_at?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface RawEntityNode {
|
||||
id: string
|
||||
label: 'ExtractedEntity'
|
||||
properties: {
|
||||
name: string
|
||||
description: string
|
||||
entity_type: string
|
||||
community_name?: string
|
||||
[key: string]: unknown
|
||||
}
|
||||
}
|
||||
|
||||
export interface RawEdge {
|
||||
id: string
|
||||
source: string
|
||||
target: string
|
||||
}
|
||||
|
||||
export interface RawCommunityGraphData {
|
||||
nodes: (RawCommunityNode | RawEntityNode)[]
|
||||
edges: RawEdge[]
|
||||
}
|
||||
|
||||
// ─── D3 graph types ───────────────────────────────────────────────────────────
|
||||
// Runtime node shape used by D3 simulations; extends SimulationNodeDatum for x/y/vx/vy.
|
||||
|
||||
export interface CommunityD3Node extends d3.SimulationNodeDatum {
|
||||
id: string
|
||||
name: string
|
||||
community: string
|
||||
label: string
|
||||
symbolSize: number
|
||||
color: string
|
||||
properties?: RawEntityNode['properties']
|
||||
}
|
||||
|
||||
export interface D3Link extends d3.SimulationLinkDatum<CommunityD3Node> {
|
||||
isCross: boolean
|
||||
}
|
||||
|
||||
// Convex-hull shape rendered behind each community cluster.
|
||||
export interface HullDatum {
|
||||
id: string
|
||||
path: string
|
||||
color: string
|
||||
labelX: number
|
||||
labelY: number
|
||||
dashed: boolean
|
||||
caption: string
|
||||
}
|
||||
|
||||
// Fully transformed graph data ready to be passed into initCommunityGraph.
|
||||
export interface CommunityGraphData {
|
||||
nodes: CommunityD3Node[]
|
||||
links: Array<{ source: string; target: string; isCross: boolean }>
|
||||
communityMap: Map<string, string[]>
|
||||
communityCaption: Map<string, string>
|
||||
communityNodeMap: Map<string, RawCommunityNode>
|
||||
}
|
||||
|
||||
// Props accepted by the CommunityGraph React component.
|
||||
export interface CommunityGraphProps {
|
||||
data: CommunityGraphData | null
|
||||
empty?: boolean
|
||||
colors?: string[]
|
||||
renderTooltip?: (node: CommunityD3Node) => ReactNode
|
||||
showLegend?: boolean
|
||||
onCommunityClick?: (node: RawCommunityNode) => void
|
||||
onNodeClick?: (node: CommunityD3Node) => void
|
||||
defaultZoom?: number
|
||||
}
|
||||
|
||||
// Options forwarded from the React component into the D3 initializer.
|
||||
export interface InitOptions {
|
||||
colors: string[]
|
||||
showLegend: boolean
|
||||
defaultZoom: number
|
||||
setTooltip: (s: { x: number; y: number; node: CommunityD3Node } | null) => void
|
||||
onCommunityClickRef: RefObject<((node: RawCommunityNode) => void) | undefined>
|
||||
onNodeClickRef: RefObject<((node: CommunityD3Node) => void) | undefined>
|
||||
}
|
||||
547
web/src/components/D3Graph/utils.ts
Normal file
547
web/src/components/D3Graph/utils.ts
Normal file
@@ -0,0 +1,547 @@
|
||||
import * as d3 from 'd3'
|
||||
import type { CommunityD3Node, D3Link, HullDatum, CommunityGraphData, RawCommunityGraphData, RawCommunityNode, RawEntityNode, InitOptions } from './types'
|
||||
|
||||
// ─── Colors ───────────────────────────────────────────────────────────────────
|
||||
|
||||
export const GRAPH_COLORS = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
|
||||
export const colorAt = (i: number) => GRAPH_COLORS[i % GRAPH_COLORS.length]
|
||||
|
||||
export function connectionToRadius(connections: number): number {
|
||||
if (connections <= 1) return 5
|
||||
if (connections <= 10) return 8
|
||||
if (connections <= 15) return 11
|
||||
if (connections <= 20) return 16
|
||||
return 22
|
||||
}
|
||||
|
||||
// ─── Arrow markers ────────────────────────────────────────────────────────────
|
||||
|
||||
export function addArrowMarkers(
|
||||
defs: d3.Selection<SVGDefsElement, unknown, null, undefined>,
|
||||
markers: { id: string; color: string }[]
|
||||
) {
|
||||
markers.forEach(({ id, color }) => {
|
||||
defs.append('marker')
|
||||
.attr('id', id)
|
||||
.attr('viewBox', '0 -4 8 8')
|
||||
.attr('refX', 8).attr('refY', 0)
|
||||
.attr('markerWidth', 6).attr('markerHeight', 6)
|
||||
.attr('orient', 'auto')
|
||||
.append('path').attr('d', 'M0,-4L8,0L0,4').attr('fill', color)
|
||||
})
|
||||
}
|
||||
|
||||
// ─── Zoom ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
export function addZoom(
|
||||
svg: d3.Selection<SVGSVGElement, unknown, null, undefined>,
|
||||
g: d3.Selection<SVGGElement, unknown, null, undefined>
|
||||
) {
|
||||
svg.call(
|
||||
d3.zoom<SVGSVGElement, unknown>().scaleExtent([0.2, 4])
|
||||
.on('zoom', e => g.attr('transform', e.transform))
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Node drag ────────────────────────────────────────────────────────────────
|
||||
|
||||
export function makeNodeDrag<N extends d3.SimulationNodeDatum>(
|
||||
simulation: d3.Simulation<N, d3.SimulationLinkDatum<N>>
|
||||
) {
|
||||
return d3.drag<SVGGElement, N>()
|
||||
.on('start', (e, d) => { if (!e.active) simulation.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y })
|
||||
.on('drag', (e, d) => { d.fx = e.x; d.fy = e.y })
|
||||
.on('end', (e, d) => { if (!e.active) simulation.alphaTarget(0); d.fx = e.x; d.fy = e.y })
|
||||
}
|
||||
|
||||
// ─── Cluster force ────────────────────────────────────────────────────────────
|
||||
// Works for both string and number group keys.
|
||||
|
||||
export function makeClusterForce<N extends d3.SimulationNodeDatum & { x?: number; y?: number; vx?: number; vy?: number }>(
|
||||
nodes: N[],
|
||||
getGroup: (d: N) => string | number,
|
||||
centers: Record<string | number, { x: number; y: number }>,
|
||||
width: number,
|
||||
height: number,
|
||||
opts: { pullStrength?: number; minSepRatio?: number; pushStrength?: number } = {}
|
||||
) {
|
||||
const { pullStrength = 0.45, minSepRatio = 0.68, pushStrength = 1.0 } = opts
|
||||
return (alpha: number) => {
|
||||
// pre-group nodes by key to avoid repeated filter() in hot path
|
||||
const groups = new Map<string, N[]>()
|
||||
nodes.forEach(d => {
|
||||
const k = String(getGroup(d))
|
||||
if (!groups.has(k)) groups.set(k, [])
|
||||
groups.get(k)!.push(d)
|
||||
})
|
||||
// pull toward group center
|
||||
nodes.forEach(d => {
|
||||
const c = centers[getGroup(d)]
|
||||
if (!c) return
|
||||
d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * pullStrength * alpha
|
||||
d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * pullStrength * alpha
|
||||
})
|
||||
// live centroids
|
||||
const centroids: Record<string, { x: number; y: number; n: number }> = {}
|
||||
nodes.forEach(d => {
|
||||
const g = String(getGroup(d))
|
||||
if (!centroids[g]) centroids[g] = { x: 0, y: 0, n: 0 }
|
||||
centroids[g].x += d.x ?? 0
|
||||
centroids[g].y += d.y ?? 0
|
||||
centroids[g].n++
|
||||
})
|
||||
Object.values(centroids).forEach(c => { c.x /= c.n; c.y /= c.n })
|
||||
// push groups apart
|
||||
const keys = Object.keys(centroids)
|
||||
const minSep = Math.min(width, height) * minSepRatio
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
for (let j = i + 1; j < keys.length; j++) {
|
||||
const ci = centroids[keys[i]], cj = centroids[keys[j]]
|
||||
const dx = cj.x - ci.x, dy = cj.y - ci.y
|
||||
const dist = Math.sqrt(dx * dx + dy * dy) || 1
|
||||
if (dist >= minSep) continue
|
||||
const push = ((minSep - dist) / dist) * pushStrength * alpha
|
||||
const fx = dx * push, fy = dy * push
|
||||
groups.get(keys[i])?.forEach(d => { d.vx = (d.vx ?? 0) - fx; d.vy = (d.vy ?? 0) - fy })
|
||||
groups.get(keys[j])?.forEach(d => { d.vx = (d.vx ?? 0) + fx; d.vy = (d.vy ?? 0) + fy })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Group centers ────────────────────────────────────────────────────────────
|
||||
|
||||
export function buildGroupCenters(
|
||||
keys: (string | number)[],
|
||||
width: number,
|
||||
height: number,
|
||||
radiusRatio = 0.4
|
||||
): Record<string | number, { x: number; y: number }> {
|
||||
const centers: Record<string | number, { x: number; y: number }> = {}
|
||||
const r = Math.min(width, height) * radiusRatio
|
||||
keys.forEach((key, i) => {
|
||||
const angle = (i / keys.length) * 2 * Math.PI - Math.PI / 2
|
||||
centers[key] = { x: width / 2 + r * Math.cos(angle), y: height / 2 + r * Math.sin(angle) }
|
||||
})
|
||||
return centers
|
||||
}
|
||||
|
||||
// ─── Community graph data transform ─────────────────────────────────────────
|
||||
|
||||
export function buildCommunityGraphData(raw: RawCommunityGraphData, colors: string[] = GRAPH_COLORS): CommunityGraphData | null {
|
||||
const getColor = (i: number) => colors[i % colors.length]
|
||||
|
||||
const communityNodes = raw.nodes.filter(n => n.label === 'Community') as RawCommunityNode[]
|
||||
const communityCaption = new Map<string, string>()
|
||||
const communityMap = new Map<string, string[]>()
|
||||
|
||||
communityNodes.forEach(n => {
|
||||
communityCaption.set(n.id, n.properties.name)
|
||||
communityMap.set(n.id, n.properties.member_entity_ids)
|
||||
})
|
||||
|
||||
const entityToCommunity = new Map<string, string>()
|
||||
communityMap.forEach((members, commId) => members.forEach(eid => entityToCommunity.set(eid, commId)))
|
||||
|
||||
const commKeys = Array.from(communityMap.keys())
|
||||
const commIndex = new Map(commKeys.map((k, i) => [k, i]))
|
||||
|
||||
const entityNodes = raw.nodes.filter(n => n.label === 'ExtractedEntity') as RawEntityNode[]
|
||||
const entityNodeSet = new Set(entityNodes.map(n => n.id))
|
||||
|
||||
const connectionCount: Record<string, number> = {}
|
||||
raw.edges.forEach(e => {
|
||||
if (entityNodeSet.has(e.source)) connectionCount[e.source] = (connectionCount[e.source] || 0) + 1
|
||||
if (entityNodeSet.has(e.target)) connectionCount[e.target] = (connectionCount[e.target] || 0) + 1
|
||||
})
|
||||
|
||||
const nodes: CommunityD3Node[] = entityNodes.map(n => {
|
||||
const commId = entityToCommunity.get(n.id) ?? commKeys[0]
|
||||
return {
|
||||
id: n.id,
|
||||
name: n.properties.name,
|
||||
community: commId,
|
||||
label: n.label,
|
||||
symbolSize: connectionToRadius(connectionCount[n.id] || 0),
|
||||
color: getColor(commIndex.get(commId) ?? 0),
|
||||
properties: n.properties,
|
||||
}
|
||||
})
|
||||
|
||||
if (!nodes.length) return null
|
||||
|
||||
const links = raw.edges
|
||||
.filter(e => entityNodeSet.has(e.source) && entityNodeSet.has(e.target))
|
||||
.map(e => ({
|
||||
source: e.source,
|
||||
target: e.target,
|
||||
isCross: entityToCommunity.get(e.source) !== entityToCommunity.get(e.target),
|
||||
}))
|
||||
|
||||
const communityNodeMap = new Map<string, RawCommunityNode>(
|
||||
communityNodes.map(n => [n.id, n])
|
||||
)
|
||||
return { nodes, links, communityMap, communityCaption, communityNodeMap }
|
||||
}
|
||||
|
||||
// ─── Hull helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
const smoothLine = d3.line<[number, number]>()
|
||||
.x(d => d[0]).y(d => d[1])
|
||||
.curve(d3.curveCatmullRomClosed.alpha(0.5))
|
||||
|
||||
function expandPoints(pts: [number, number][], pad: number): [number, number][] {
|
||||
const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length
|
||||
const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length
|
||||
return pts.map(([x, y]) => {
|
||||
const dx = x - cx, dy = y - cy
|
||||
const len = Math.sqrt(dx * dx + dy * dy) || 1
|
||||
return [x + (dx / len) * pad, y + (dy / len) * pad]
|
||||
})
|
||||
}
|
||||
|
||||
function toHullPoints(pts: [number, number][]): [number, number][] {
|
||||
if (pts.length === 1) {
|
||||
const [x, y] = pts[0]
|
||||
return [[x - 1, y - 1], [x + 1, y - 1], [x, y + 1]]
|
||||
}
|
||||
if (pts.length === 2) {
|
||||
const [[x1, y1], [x2, y2]] = pts
|
||||
return [[x1, y1], [x2, y2], [(x1 + x2) / 2, (y1 + y2) / 2 - 1]]
|
||||
}
|
||||
return d3.polygonHull(pts) ?? pts
|
||||
}
|
||||
|
||||
const CIRCLE_THRESHOLD = 4 // 节点数 < 此值时使用圆形
|
||||
const CIRCLE_SEGMENTS = 32
|
||||
|
||||
function circlePoints(cx: number, cy: number, r: number): [number, number][] {
|
||||
return Array.from({ length: CIRCLE_SEGMENTS }, (_, i) => {
|
||||
const a = (i / CIRCLE_SEGMENTS) * 2 * Math.PI
|
||||
return [cx + r * Math.cos(a), cy + r * Math.sin(a)] as [number, number]
|
||||
})
|
||||
}
|
||||
|
||||
export function buildHullData(
|
||||
nodes: CommunityD3Node[],
|
||||
communityMap: Map<string, string[]>,
|
||||
communityCaption: Map<string, string>,
|
||||
colors: string[]
|
||||
): HullDatum[] {
|
||||
const getColor = (i: number) => colors[i % colors.length]
|
||||
const byComm = new Map<string, [number, number][]>()
|
||||
communityMap.forEach((_, id) => byComm.set(id, []))
|
||||
nodes.forEach(d => {
|
||||
if (d.x != null && d.y != null) byComm.get(d.community)?.push([d.x, d.y])
|
||||
})
|
||||
|
||||
const hulls: HullDatum[] = []
|
||||
let ci = 0
|
||||
byComm.forEach((pts, id) => {
|
||||
const color = getColor(ci++)
|
||||
if (!pts.length) return
|
||||
let pathPoints: [number, number][]
|
||||
if (pts.length < CIRCLE_THRESHOLD) {
|
||||
const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length
|
||||
const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length
|
||||
pathPoints = circlePoints(cx, cy, 60)
|
||||
} else {
|
||||
pathPoints = expandPoints(toHullPoints(pts), 60) as [number, number][]
|
||||
}
|
||||
const path = smoothLine(pathPoints)
|
||||
if (!path) return
|
||||
hulls.push({
|
||||
id, path, color,
|
||||
labelX: pathPoints.reduce((s, p) => s + p[0], 0) / pathPoints.length,
|
||||
labelY: Math.min(...pathPoints.map(p => p[1])) - 10,
|
||||
dashed: pts.length <= 2,
|
||||
caption: communityCaption.get(id) ?? id,
|
||||
})
|
||||
})
|
||||
return hulls
|
||||
}
|
||||
|
||||
// ─── Hull render ──────────────────────────────────────────────────────────────
|
||||
|
||||
export function renderHulls(
|
||||
hullG: d3.Selection<SVGGElement, unknown, null, undefined>,
|
||||
hulls: HullDatum[],
|
||||
hiddenCommunities: Set<string>,
|
||||
nodes: CommunityD3Node[],
|
||||
simulation: d3.Simulation<CommunityD3Node, D3Link>,
|
||||
onCommunityClick?: (node: RawCommunityNode) => void,
|
||||
communityNodeMap?: Map<string, RawCommunityNode>
|
||||
) {
|
||||
let dragNodes: CommunityD3Node[] = []
|
||||
let dragStart = { x: 0, y: 0 }
|
||||
const communityDrag = d3.drag<SVGPathElement, HullDatum>()
|
||||
.on('start', (event, d) => {
|
||||
if (!event.active) simulation.alphaTarget(0.3).restart()
|
||||
dragNodes = nodes.filter(n => n.community === d.id)
|
||||
dragStart = { x: event.x, y: event.y }
|
||||
dragNodes.forEach(n => { n.fx = n.x; n.fy = n.y })
|
||||
})
|
||||
.on('drag', (event) => {
|
||||
const dx = event.x - dragStart.x, dy = event.y - dragStart.y
|
||||
dragStart = { x: event.x, y: event.y }
|
||||
dragNodes.forEach(n => { n.fx = (n.fx ?? n.x ?? 0) + dx; n.fy = (n.fy ?? n.y ?? 0) + dy })
|
||||
})
|
||||
.on('end', (event) => { if (!event.active) simulation.alphaTarget(0) })
|
||||
|
||||
const pathSel = hullG.selectAll<SVGPathElement, HullDatum>('path.hull').data(hulls, d => d.id)
|
||||
pathSel.enter().append('path').attr('class', 'hull').style('cursor', 'grab')
|
||||
.merge(pathSel)
|
||||
.call(communityDrag)
|
||||
.attr('d', d => d.path)
|
||||
.attr('fill', d => d.color).attr('fill-opacity', 0.08)
|
||||
.attr('stroke', d => d.color).attr('stroke-opacity', 0.5).attr('stroke-width', 1.5)
|
||||
.attr('stroke-dasharray', 'none')
|
||||
.style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
|
||||
.on('click', (event, d) => {
|
||||
if ((event as MouseEvent).defaultPrevented) return
|
||||
const node = communityNodeMap?.get(d.id)
|
||||
if (node) onCommunityClick?.(node)
|
||||
})
|
||||
pathSel.exit().remove()
|
||||
|
||||
const labelSel = hullG.selectAll<SVGTextElement, HullDatum>('text.hull-label').data(hulls, d => d.id)
|
||||
labelSel.enter().append('text').attr('class', 'hull-label')
|
||||
.attr('text-anchor', 'middle').attr('font-size', '12px').attr('font-weight', '500')
|
||||
.style('pointer-events', 'none')
|
||||
.merge(labelSel)
|
||||
.attr('x', d => d.labelX).attr('y', d => d.labelY)
|
||||
.attr('fill', d => d.color)
|
||||
.style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
|
||||
.text(d => d.caption)
|
||||
labelSel.exit().remove()
|
||||
}
|
||||
|
||||
// ─── Community graph init ─────────────────────────────────────────────────────
|
||||
|
||||
export function initCommunityGraph(
|
||||
container: HTMLDivElement,
|
||||
nodes: CommunityD3Node[],
|
||||
links: D3Link[],
|
||||
communityMap: Map<string, string[]>,
|
||||
communityCaption: Map<string, string>,
|
||||
communityNodeMap: Map<string, RawCommunityNode>,
|
||||
opts: InitOptions
|
||||
) {
|
||||
const { colors, showLegend, defaultZoom, setTooltip, onCommunityClickRef, onNodeClickRef } = opts
|
||||
const getColor = (i: number) => colors[i % colors.length]
|
||||
|
||||
const width = container.clientWidth || 600
|
||||
const height = container.clientHeight || 518
|
||||
|
||||
const svg = d3.select(container).append('svg')
|
||||
.attr('width', width).attr('height', height)
|
||||
.style('width', '100%').style('height', '100%')
|
||||
.style('background', '#F6F8FC')
|
||||
|
||||
const g = svg.append('g')
|
||||
|
||||
const zoom = d3.zoom<SVGSVGElement, unknown>()
|
||||
.scaleExtent([0.2, 4])
|
||||
.on('zoom', e => g.attr('transform', e.transform))
|
||||
svg.call(zoom)
|
||||
if (defaultZoom !== 1) {
|
||||
svg.call(zoom.transform, d3.zoomIdentity
|
||||
.translate(width / 2 * (1 - defaultZoom), height / 2 * (1 - defaultZoom))
|
||||
.scale(defaultZoom)
|
||||
)
|
||||
}
|
||||
|
||||
const defs = svg.append('defs')
|
||||
addArrowMarkers(defs, [{ id: 'arrow', color: 'rgba(91, 97, 103, 0.7)' }])
|
||||
|
||||
const commKeys = Array.from(communityMap.keys())
|
||||
const centers = buildGroupCenters(commKeys, width, height, 0.45)
|
||||
const linkedIds = new Set(links.flatMap(l => [l.source as string, l.target as string]))
|
||||
|
||||
const simulation = d3.forceSimulation(nodes)
|
||||
.force('link', d3.forceLink<CommunityD3Node, D3Link>(links).id(d => d.id).distance(60))
|
||||
.force('charge', d3.forceManyBody().strength(-120))
|
||||
.force('center', d3.forceCenter(width / 2, height / 2).strength(0.02))
|
||||
.force('collision', d3.forceCollide<CommunityD3Node>(d => d.symbolSize + 16))
|
||||
.force('cluster', makeClusterForce(nodes, d => d.community, centers, width, height, {
|
||||
pullStrength: 0.45, minSepRatio: 0.68, pushStrength: 1.0,
|
||||
}))
|
||||
.force('isolatedPull', (alpha: number) => {
|
||||
nodes.forEach(d => {
|
||||
if (linkedIds.has(d.id)) return
|
||||
const c = centers[d.community]
|
||||
if (!c) return
|
||||
d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * 0.4 * alpha
|
||||
d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * 0.4 * alpha
|
||||
})
|
||||
})
|
||||
|
||||
const hullG = g.append('g').attr('class', 'hulls')
|
||||
const hiddenCommunities = new Set<string>()
|
||||
|
||||
const linkSel = g.append('g').selectAll<SVGLineElement, D3Link>('line')
|
||||
.data(links).enter().append('line')
|
||||
.attr('stroke', '#5B6167')
|
||||
.attr('stroke-opacity', d => d.isCross ? 0.3 : 0.5)
|
||||
.attr('stroke-width', d => d.isCross ? 1 : 1.2)
|
||||
.attr('marker-end', 'url(#arrow)')
|
||||
|
||||
const nodeSel = g.append('g').selectAll<SVGGElement, CommunityD3Node>('g')
|
||||
.data(nodes).enter().append('g')
|
||||
.call(makeNodeDrag(simulation))
|
||||
|
||||
nodeSel.append('circle')
|
||||
.attr('r', d => d.symbolSize)
|
||||
.attr('fill', d => d.color).attr('fill-opacity', 0.85)
|
||||
.attr('stroke', '#fff').attr('stroke-width', 1.5)
|
||||
.style('cursor', 'pointer')
|
||||
.on('mouseenter', (event: MouseEvent, d: CommunityD3Node) => {
|
||||
const { left, top } = container.getBoundingClientRect()
|
||||
setTooltip({ x: event.clientX - left, y: event.clientY - top, node: d })
|
||||
})
|
||||
.on('mousemove', (event: MouseEvent) => {
|
||||
const { left, top } = container.getBoundingClientRect()
|
||||
const nd = d3.select<SVGCircleElement, CommunityD3Node>(event.target as SVGCircleElement).datum()
|
||||
setTooltip({ x: event.clientX - left, y: event.clientY - top, node: nd })
|
||||
})
|
||||
.on('mouseleave', () => setTooltip(null))
|
||||
.on('click', (_event: MouseEvent, d: CommunityD3Node) => onNodeClickRef.current?.(d))
|
||||
|
||||
nodeSel.append('text')
|
||||
.text(d => d.name)
|
||||
.attr('x', 0).attr('dy', d => -(d.symbolSize + 5))
|
||||
.attr('text-anchor', 'middle').attr('font-size', '11px').attr('fill', '#444')
|
||||
.style('pointer-events', 'none')
|
||||
|
||||
if (showLegend) {
|
||||
renderLegend(
|
||||
svg,
|
||||
commKeys.map((cid, i) => ({ key: cid, label: communityCaption.get(cid) ?? cid, color: getColor(i) })),
|
||||
width, height,
|
||||
(key, hidden) => {
|
||||
const cid = key as string
|
||||
if (hidden) hiddenCommunities.add(cid)
|
||||
else hiddenCommunities.delete(cid)
|
||||
nodeSel.style('display', d => hiddenCommunities.has(d.community) ? 'none' : null)
|
||||
linkSel.style('display', d => {
|
||||
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
|
||||
return hiddenCommunities.has(s.community) || hiddenCommunities.has(t.community) ? 'none' : null
|
||||
})
|
||||
hullG.selectAll<SVGPathElement, HullDatum>('path.hull').style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
|
||||
hullG.selectAll<SVGTextElement, HullDatum>('text.hull-label').style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
simulation.on('tick', () => {
|
||||
linkSel
|
||||
.attr('x1', d => (d.source as CommunityD3Node).x ?? 0)
|
||||
.attr('y1', d => (d.source as CommunityD3Node).y ?? 0)
|
||||
.attr('x2', d => {
|
||||
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
|
||||
const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0)
|
||||
const dist = Math.sqrt(dx * dx + dy * dy) || 1
|
||||
return (t.x ?? 0) - (dx / dist) * (t.symbolSize + 2)
|
||||
})
|
||||
.attr('y2', d => {
|
||||
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
|
||||
const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0)
|
||||
const dist = Math.sqrt(dx * dx + dy * dy) || 1
|
||||
return (t.y ?? 0) - (dy / dist) * (t.symbolSize + 2)
|
||||
})
|
||||
nodeSel.attr('transform', d => `translate(${d.x ?? 0},${d.y ?? 0})`)
|
||||
renderHulls(hullG, buildHullData(nodes, communityMap, communityCaption, colors), hiddenCommunities, nodes, simulation, (n) => onCommunityClickRef.current?.(n), communityNodeMap)
|
||||
})
|
||||
|
||||
return () => { simulation.stop(); d3.select(container).selectAll('svg').remove() }
|
||||
}
|
||||
|
||||
// ─── Legend ───────────────────────────────────────────────────────────────────
|
||||
|
||||
export interface LegendItem {
|
||||
key: string | number
|
||||
label: string
|
||||
color: string
|
||||
}
|
||||
|
||||
const LEGEND_GAP = 12
|
||||
const LEGEND_RECT_W = 20
|
||||
const LEGEND_RECT_H = 10
|
||||
const LEGEND_TEXT_OFFSET = 24
|
||||
const LEGEND_FONT_SIZE = 11
|
||||
const LEGEND_ROW_H = 24
|
||||
const LEGEND_BOTTOM_PAD = 8
|
||||
|
||||
// Approximate text width using canvas measureText if available, else char-based estimate
|
||||
function measureText(text: string, fontSize: number): number {
|
||||
try {
|
||||
const ctx = document.createElement('canvas').getContext('2d')
|
||||
if (ctx) { ctx.font = `${fontSize}px sans-serif`; return ctx.measureText(text).width }
|
||||
} catch { /* noop */ }
|
||||
return text.length * fontSize * 0.6
|
||||
}
|
||||
|
||||
export function renderLegend(
|
||||
svg: d3.Selection<SVGSVGElement, unknown, null, undefined>,
|
||||
items: LegendItem[],
|
||||
width: number,
|
||||
height: number,
|
||||
onToggle: (key: string | number, hidden: boolean) => void
|
||||
) {
|
||||
// Compute per-item width: rect + text-offset + textW
|
||||
const itemWidths = items.map(item =>
|
||||
LEGEND_RECT_W + LEGEND_TEXT_OFFSET + measureText(item.label, LEGEND_FONT_SIZE)
|
||||
)
|
||||
|
||||
// Layout items into rows
|
||||
const rows: { item: LegendItem; w: number; x: number; row: number }[] = []
|
||||
let rowIdx = 0, curX = 0
|
||||
itemWidths.forEach((w, i) => {
|
||||
const slotW = w + LEGEND_GAP
|
||||
if (curX > 0 && curX + w > width - LEGEND_GAP * 2) { rowIdx++; curX = 0 }
|
||||
rows.push({ item: items[i], w, x: curX, row: rowIdx })
|
||||
curX += slotW
|
||||
})
|
||||
|
||||
const totalRows = rowIdx + 1
|
||||
const totalH = totalRows * LEGEND_ROW_H
|
||||
const baseY = height - totalH - LEGEND_BOTTOM_PAD
|
||||
|
||||
// Center each row
|
||||
const rowWidths: number[] = Array(totalRows).fill(0)
|
||||
rows.forEach(({ w, row }, i) => {
|
||||
rowWidths[row] += w + (i > 0 && rows[i - 1].row === row ? LEGEND_GAP : 0)
|
||||
})
|
||||
// Recalculate row widths properly
|
||||
const rowTotals: number[] = Array(totalRows).fill(0)
|
||||
const rowCounts: number[] = Array(totalRows).fill(0)
|
||||
rows.forEach(r => { rowCounts[r.row]++; rowTotals[r.row] += r.w })
|
||||
rowTotals.forEach((_, ri) => { rowTotals[ri] += Math.max(0, rowCounts[ri] - 1) * LEGEND_GAP })
|
||||
|
||||
const legendG = svg.append('g')
|
||||
|
||||
rows.forEach(({ item, x, row }) => {
|
||||
const rowOffsetX = (width - rowTotals[row]) / 2
|
||||
const g = legendG.append('g')
|
||||
.attr('transform', `translate(${rowOffsetX + x},${baseY + row * LEGEND_ROW_H + LEGEND_ROW_H / 2})`)
|
||||
.style('cursor', 'pointer')
|
||||
|
||||
const rect = g.append('rect')
|
||||
.attr('x', 0).attr('y', -LEGEND_RECT_H / 2)
|
||||
.attr('width', LEGEND_RECT_W).attr('height', LEGEND_RECT_H).attr('rx', 2)
|
||||
.attr('fill', item.color)
|
||||
|
||||
const text = g.append('text')
|
||||
.text(item.label)
|
||||
.attr('x', LEGEND_TEXT_OFFSET).attr('dy', '0.35em')
|
||||
.attr('font-size', `${LEGEND_FONT_SIZE}px`).attr('fill', '#5B6167')
|
||||
|
||||
let hidden = false
|
||||
g.on('click', () => {
|
||||
hidden = !hidden
|
||||
rect.attr('fill', hidden ? '#ccc' : item.color)
|
||||
text.attr('fill', hidden ? '#bbb' : '#5B6167')
|
||||
onToggle(item.key, hidden)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -1482,6 +1482,33 @@ export const en = {
|
||||
memoryNum: 'memories',
|
||||
memory_config_name: 'Memory Engine',
|
||||
searchPlaceholder: 'Search memory store name',
|
||||
|
||||
communityNetwork: 'Community Graph',
|
||||
community: 'Community',
|
||||
"Person": "Person Entity Node",
|
||||
"Organization": "Organization Entity Node",
|
||||
"ORG": "Organization Entity Node",
|
||||
"Location": "Location Entity Node",
|
||||
"LOC": "Location Entity Node",
|
||||
"Event": "Event Entity Node",
|
||||
"Concept": "Concept Entity Node",
|
||||
"Time": "Time Entity Node",
|
||||
"Position": "Position Entity Node",
|
||||
"WorkRole": "Work Role Entity Node",
|
||||
"System": "System Entity Node",
|
||||
"Policy": "Policy Entity Node",
|
||||
"HistoricalPeriod": "Historical Period Entity Node",
|
||||
"HistoricalState": "Historical State Entity Node",
|
||||
"HistoricalEvent": "Historical Event Entity Node",
|
||||
"EconomicFactor": "Economic Factor Entity Node",
|
||||
"Condition": "Condition Entity Node",
|
||||
"Numeric": "Numeric Entity Node",
|
||||
"Work": "Work / Output",
|
||||
member_count: 'Member Count',
|
||||
member_count_desc: 'entities',
|
||||
summary: 'Summary',
|
||||
core_entities: 'Core Entities',
|
||||
communityDetailEmptyDesc: 'Click on a community in the chart on the left to view details',
|
||||
},
|
||||
space: {
|
||||
createSpace: 'Create Space',
|
||||
|
||||
@@ -1480,6 +1480,33 @@ export const zh = {
|
||||
memoryNum: '条记忆',
|
||||
memory_config_name: '记忆引擎',
|
||||
searchPlaceholder: '搜索记忆库名称',
|
||||
|
||||
communityNetwork: '社区图谱',
|
||||
community: '社区',
|
||||
"Person": "人物实体节点",
|
||||
"Organization": "组织实体节点",
|
||||
"ORG": "组织实体节点",
|
||||
"Location": "地点实体节点",
|
||||
"LOC": "地点实体节点",
|
||||
"Event": "事件实体节点",
|
||||
"Concept": "概念实体节点",
|
||||
"Time": "时间实体节点",
|
||||
"Position": "职位实体节点",
|
||||
"WorkRole": "职业实体节点",
|
||||
"System": "系统实体节点",
|
||||
"Policy": "政策实体节点",
|
||||
"HistoricalPeriod": "历史时期实体节点",
|
||||
"HistoricalState": "历史国家实体节点",
|
||||
"HistoricalEvent": "历史事件实体节点",
|
||||
"EconomicFactor": "经济因素实体节点",
|
||||
"Condition": "条件实体节点",
|
||||
"Numeric": "数值实体节点",
|
||||
"Work": "作品/工作成果",
|
||||
member_count: '成员数',
|
||||
member_count_desc: '个实体',
|
||||
summary: '摘要',
|
||||
core_entities: '核心实体',
|
||||
communityDetailEmptyDesc: '点击左侧图表中的社区查看详情',
|
||||
},
|
||||
space: {
|
||||
createSpace: '创建空间',
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import React, { useState, type FC, useEffect } from 'react'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { CommunityD3Node, CommunityGraphData, RawCommunityGraphData, RawCommunityNode } from '@/components/D3Graph/types'
|
||||
import { buildCommunityGraphData } from '@/components/D3Graph/utils'
|
||||
import CommunityGraph from '@/components/D3Graph/CommunityGraph'
|
||||
import { getMemoryCommunityGraph } from '@/api/memory'
|
||||
|
||||
// ─── Tooltip ──────────────────────────────────────────────────────────────────
|
||||
|
||||
const NodeTooltip: FC<{ node: CommunityD3Node }> = ({ node }) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<div style={{
|
||||
background: '#fff', border: '1px solid #DFE4ED', borderRadius: 8,
|
||||
boxShadow: '0 4px 16px rgba(0,0,0,0.12)', padding: '10px 14px',
|
||||
minWidth: 180, maxWidth: 260, fontSize: 13,
|
||||
}}>
|
||||
<div style={{ fontWeight: 600, marginBottom: 6, color: '#1a1a1a', fontSize: 14 }}>
|
||||
{node.properties?.name ?? node.name}
|
||||
</div>
|
||||
{node.properties?.description && (
|
||||
<div style={{ color: '#5B6167', lineHeight: '20px', marginBottom: 4 }}>
|
||||
{node.properties.description}
|
||||
</div>
|
||||
)}
|
||||
<div style={{ color: '#5B6167', lineHeight: '22px' }}>
|
||||
{t('userMemory.type')}:
|
||||
<span style={{ color: '#1a1a1a' }}>{t(`userMemory.${node.properties?.entity_type}`)}</span>
|
||||
</div>
|
||||
<div style={{ color: '#5B6167', lineHeight: '22px' }}>
|
||||
{t('userMemory.community')}:
|
||||
<span style={{ color: node.color, fontWeight: 500 }}>{node.properties?.community_name}</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ─── Component ────────────────────────────────────────────────────────────────
|
||||
|
||||
const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => void }> = ({ onSelectCommunity }) => {
|
||||
const { id } = useParams()
|
||||
const [graphData, setGraphData] = useState<CommunityGraphData | null>(null)
|
||||
const [empty, setEmpty] = useState(false)
|
||||
|
||||
useEffect(() => {
|
||||
if (!id) return
|
||||
const controller = new AbortController()
|
||||
setEmpty(false)
|
||||
setGraphData(null)
|
||||
getMemoryCommunityGraph(id, { signal: controller.signal }).then(res => {
|
||||
const raw = res as RawCommunityGraphData
|
||||
if (!raw.nodes?.length) { setEmpty(true); return }
|
||||
const built = buildCommunityGraphData(raw)
|
||||
if (!built) { setEmpty(true); return }
|
||||
setGraphData(built)
|
||||
}).catch((e) => { if (e?.code !== 'ERR_CANCELED') setEmpty(true) })
|
||||
return () => controller.abort()
|
||||
}, [id])
|
||||
|
||||
return (
|
||||
<CommunityGraph
|
||||
data={graphData}
|
||||
empty={empty}
|
||||
showLegend={false}
|
||||
onCommunityClick={onSelectCommunity}
|
||||
renderTooltip={node => <NodeTooltip node={node} />}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(CommunityNetwork)
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 18:32:00
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 18:32:00
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-13 14:51:17
|
||||
*/
|
||||
/**
|
||||
* Relationship Network Component
|
||||
@@ -13,18 +13,20 @@
|
||||
import React, { type FC, useEffect, useState, useRef, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams, useNavigate } from 'react-router-dom'
|
||||
import { Col, Row, Space, Button } from 'antd'
|
||||
import { Col, Row, Space, Button, Tabs, Flex, Divider } from 'antd'
|
||||
import dayjs from 'dayjs'
|
||||
import ReactEcharts from 'echarts-for-react'
|
||||
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import detailEmpty from '@/assets/images/userMemory/detail_empty.png'
|
||||
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties } from '../types'
|
||||
import type { RawCommunityNode } from '@/components/D3Graph/types'
|
||||
import {
|
||||
getMemorySearchEdges,
|
||||
} from '@/api/memory'
|
||||
import Empty from '@/components/Empty'
|
||||
import Tag from '@/components/Tag'
|
||||
import CommunityNetwork from './CommunityNetwork'
|
||||
|
||||
/** Node color palette */
|
||||
const colors = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
|
||||
@@ -36,16 +38,21 @@ const RelationshipNetwork:FC = () => {
|
||||
const [nodes, setNodes] = useState<Node[]>([])
|
||||
const [links, setLinks] = useState<Edge[]>([])
|
||||
const [categories, setCategories] = useState<{ name: string }[]>([])
|
||||
const [selectedNode, setSelectedNode] = useState<Node | null>(null)
|
||||
const [selectedNode, setSelectedNode] = useState<Node | RawCommunityNode | null>(null)
|
||||
// const [fullScreen, setFullScreen] = useState<boolean>(false)
|
||||
const navigate = useNavigate()
|
||||
const [activeTab, setActiveTab] = useState('relationshipNetwork')
|
||||
|
||||
console.log('categories', categories)
|
||||
const edgeAbortRef = useRef<AbortController | null>(null)
|
||||
|
||||
/** Fetch relationship network data */
|
||||
const getEdgeData = useCallback(() => {
|
||||
if (!id) return
|
||||
edgeAbortRef.current?.abort()
|
||||
edgeAbortRef.current = new AbortController()
|
||||
setSelectedNode(null)
|
||||
getMemorySearchEdges(id).then((res) => {
|
||||
getMemorySearchEdges(id, { signal: edgeAbortRef.current.signal }).then((res) => {
|
||||
const { nodes, edges, statistics } = res as GraphData
|
||||
const curNodes: Node[] = []
|
||||
const curEdges: Edge[] = []
|
||||
@@ -123,6 +130,7 @@ const RelationshipNetwork:FC = () => {
|
||||
useEffect(() => {
|
||||
if (!id) return
|
||||
getEdgeData()
|
||||
return () => { edgeAbortRef.current?.abort() }
|
||||
}, [id])
|
||||
|
||||
useEffect(() => {
|
||||
@@ -153,34 +161,36 @@ const RelationshipNetwork:FC = () => {
|
||||
const params = new URLSearchParams({
|
||||
nodeId: selectedNode.id,
|
||||
nodeLabel: selectedNode.label,
|
||||
nodeName: selectedNode.name || ''
|
||||
nodeName: (selectedNode as Node).name || ''
|
||||
})
|
||||
navigate(`/user-memory/detail/${id}/GRAPH?${params.toString()}`)
|
||||
}
|
||||
const handleChangeTab = (tab: string) => {
|
||||
if (tab === 'communityNetwork') {
|
||||
edgeAbortRef.current?.abort()
|
||||
} else {
|
||||
getEdgeData()
|
||||
}
|
||||
setActiveTab(tab)
|
||||
setSelectedNode(null)
|
||||
}
|
||||
|
||||
return (
|
||||
<Row gutter={16}>
|
||||
{/* Relationship Network */}
|
||||
<Col span={16}>
|
||||
<RbCard
|
||||
title={t('userMemory.relationshipNetwork')}
|
||||
headerType="borderless"
|
||||
headerClassName="rb:min-h-[46px]!"
|
||||
// extra={
|
||||
// <div
|
||||
// onClick={handleFullScreen}
|
||||
// className="rb:group rb:cursor-pointer rb:hover:text-[#212332] rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:flex rb:items-center rb:gap-1"
|
||||
// >
|
||||
// <div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/fullScreen.svg')] rb:hover:bg-[url('@/assets/images/fullScreen_hover.svg')]"></div>
|
||||
// {t('userMemory.fullScreen')}
|
||||
// </div>
|
||||
// }
|
||||
>
|
||||
<RbCard bodyClassName="rb:pt-0!">
|
||||
<Tabs
|
||||
items={['relationshipNetwork', 'communityNetwork'].map(key => ({ key, label: t(`userMemory.${key}`) }))}
|
||||
activeKey={activeTab}
|
||||
onChange={handleChangeTab}
|
||||
/>
|
||||
<div className="rb:h-129.5 rb:bg-[#F6F8FC] rb:border rb:border-[#DFE4ED] rb:rounded-sm">
|
||||
{nodes.length === 0 ? (
|
||||
<Empty className="rb:h-full" />
|
||||
) : (
|
||||
<ReactEcharts
|
||||
{activeTab === 'communityNetwork'
|
||||
? <CommunityNetwork onSelectCommunity={community => setSelectedNode(community)} />
|
||||
: nodes.length === 0
|
||||
? <Empty className="rb:h-full" />
|
||||
: <ReactEcharts
|
||||
option={{
|
||||
colors: colors,
|
||||
tooltip: {
|
||||
@@ -253,103 +263,121 @@ const RelationshipNetwork:FC = () => {
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
}
|
||||
</div>
|
||||
</RbCard>
|
||||
</Col>
|
||||
{/* Memory Details */}
|
||||
<Col span={8}>
|
||||
<RbCard
|
||||
<RbCard
|
||||
title={t('userMemory.memoryDetails')}
|
||||
headerType="borderless"
|
||||
headerClassName="rb:min-h-[46px]!"
|
||||
bodyClassName='rb:p-0!'
|
||||
extra={selectedNode && <Button type="text" onClick={handleViewAll}>
|
||||
<div
|
||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/view.svg')] rb:hover:bg-[url('@/assets/images/userMemory/view_hover.svg')]"
|
||||
></div>
|
||||
{t('userMemory.completeMemory')}
|
||||
</Button>}
|
||||
bodyClassName="rb:p-0!"
|
||||
extra={selectedNode && !(selectedNode as RawCommunityNode).properties.community_id && (
|
||||
<Button type="text" onClick={handleViewAll}>
|
||||
<div className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/view.svg')] rb:hover:bg-[url('@/assets/images/userMemory/view_hover.svg')]" />
|
||||
{t('userMemory.completeMemory')}
|
||||
</Button>
|
||||
)}
|
||||
>
|
||||
<div className="rb:h-133.5 rb:overflow-y-auto">
|
||||
{!selectedNode
|
||||
? <Empty
|
||||
url={detailEmpty}
|
||||
subTitle={t('userMemory.memoryDetailEmptyDesc')}
|
||||
className="rb:h-full rb:mx-10 rb:text-center"
|
||||
size={[197.81, 150]}
|
||||
/>
|
||||
: <>
|
||||
{selectedNode.name && <div className="rb:bg-[#F6F8FC] rb:border-t rb:border-b rb:border-[#DFE4ED] rb:font-medium rb:py-2 rb:px-4 rb:h-10">{selectedNode.name}</div>}
|
||||
<div className="rb:p-4">
|
||||
<>
|
||||
? <Empty url={detailEmpty} subTitle={activeTab === 'relationshipNetwork' ? t('userMemory.memoryDetailEmptyDesc') : t('userMemory.communityDetailEmptyDesc')} className="rb:h-full rb:mx-10 rb:text-center" size={[197.81, 150]} />
|
||||
: (selectedNode as RawCommunityNode).properties.community_id
|
||||
? <div className="rb:p-3 rb:pt-0">
|
||||
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
|
||||
{(selectedNode as RawCommunityNode).properties.name}
|
||||
</div>
|
||||
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
|
||||
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.summary}
|
||||
</div>
|
||||
<Flex align="center" justify="space-between" className="rb:mt-5!">
|
||||
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
|
||||
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
|
||||
</Flex>
|
||||
|
||||
<Divider className='rb:my-2.5!' />
|
||||
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
|
||||
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
|
||||
{(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) => <li key={index}>{entity}</li>)}
|
||||
</ul>
|
||||
</div>
|
||||
: <>
|
||||
{(selectedNode as Node).name && (
|
||||
<div className="rb:bg-[#F6F8FC] rb:border-t rb:border-b rb:border-[#DFE4ED] rb:font-medium rb:py-2 rb:px-4 rb:h-10">
|
||||
{(selectedNode as Node).name}
|
||||
</div>
|
||||
)}
|
||||
<div className="rb:p-4">
|
||||
<div className="rb:font-medium rb:leading-5">{t('userMemory.memoryContent')}</div>
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{['Chunk', 'Dialogue', 'MemorySummary'].includes(selectedNode.label) && 'content' in selectedNode.properties
|
||||
? selectedNode.properties.content
|
||||
: selectedNode.label === 'ExtractedEntity' && 'description' in selectedNode.properties
|
||||
? selectedNode.properties.description
|
||||
: selectedNode.label === 'Statement' && 'statement' in selectedNode.properties
|
||||
? selectedNode.properties.statement
|
||||
: ''
|
||||
}
|
||||
</div>
|
||||
</>
|
||||
<div className="rb:font-medium rb:mb-2 rb:mt-4">
|
||||
<div className="rb:font-medium rb:leading-5">{t('userMemory.created_at')}</div>
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{dayjs(selectedNode?.properties.created_at).format('YYYY-MM-DD HH:mm:ss')}
|
||||
? selectedNode.properties.description
|
||||
: selectedNode.label === 'Statement' && 'statement' in selectedNode.properties
|
||||
? selectedNode.properties.statement
|
||||
: ''}
|
||||
</div>
|
||||
|
||||
{selectedNode?.properties.associative_memory > 0 && <div className="rb:mt-4">
|
||||
<div className="rb:font-medium rb:leading-5">{t('userMemory.associative_memory')}</div>
|
||||
<div className="rb:font-medium rb:mb-2 rb:mt-4">
|
||||
<div className="rb:font-medium rb:leading-5">{t('userMemory.created_at')}</div>
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
<span className="rb:text-[#155EEF] rb:font-medium">{selectedNode?.properties.associative_memory}</span> {t('userMemory.unix')}{t('userMemory.associative_memory')}
|
||||
{dayjs((selectedNode as Node).properties.created_at).format('YYYY-MM-DD HH:mm:ss')}
|
||||
</div>
|
||||
</div>}
|
||||
|
||||
{selectedNode.label === 'Statement' && <>
|
||||
{(['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => {
|
||||
const statementProps = selectedNode.properties as StatementNodeProperties;
|
||||
if ((key === 'emotion_keywords' && statementProps[key]?.length > 0) || typeof statementProps[key] === 'string') {
|
||||
console.log('statementProps[key]', statementProps[key])
|
||||
return (
|
||||
<div className="rb:mt-4" key={key}>
|
||||
{t(`userMemory.Statement_${key}`)}
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{key === 'emotion_keywords'
|
||||
? <Space>{statementProps.emotion_keywords.map((vo, index) => <Tag key={index}>{vo}</Tag>)}</Space>
|
||||
: statementProps[key]
|
||||
}
|
||||
{(selectedNode as Node).properties.associative_memory > 0 && (
|
||||
<div className="rb:mt-4">
|
||||
<div className="rb:font-medium rb:leading-5">{t('userMemory.associative_memory')}</div>
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
<span className="rb:text-[#155EEF] rb:font-medium">{(selectedNode as Node).properties.associative_memory}</span>
|
||||
{' '}{t('userMemory.unix')}{t('userMemory.associative_memory')}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedNode.label === 'Statement' && (
|
||||
(['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => {
|
||||
const p = selectedNode.properties as StatementNodeProperties
|
||||
if ((key === 'emotion_keywords' && p[key]?.length > 0) || typeof p[key] === 'string') {
|
||||
return (
|
||||
<div className="rb:mt-4" key={key}>
|
||||
{t(`userMemory.Statement_${key}`)}
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{key === 'emotion_keywords'
|
||||
? <Space>{p.emotion_keywords.map((v, i) => <Tag key={i}>{v}</Tag>)}</Space>
|
||||
: p[key]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return null
|
||||
})}
|
||||
</>}
|
||||
{selectedNode.label === 'ExtractedEntity' && <>
|
||||
{(['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => {
|
||||
const entityProps = selectedNode.properties as ExtractedEntityNodeProperties;
|
||||
if (entityProps[key]) {
|
||||
return (
|
||||
<div className="rb:mt-4" key={key}>
|
||||
{t(`userMemory.ExtractedEntity_${key}`)}
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{Array.isArray(entityProps[key]) && entityProps[key].length > 0
|
||||
? entityProps[key].map((vo, index) => <div key={index}>- {vo}</div>)
|
||||
: entityProps[key]
|
||||
}
|
||||
)
|
||||
}
|
||||
return null
|
||||
})
|
||||
)}
|
||||
|
||||
{selectedNode.label === 'ExtractedEntity' && (
|
||||
(['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => {
|
||||
const p = selectedNode.properties as ExtractedEntityNodeProperties
|
||||
if (p[key]) {
|
||||
return (
|
||||
<div className="rb:mt-4" key={key}>
|
||||
{t(`userMemory.ExtractedEntity_${key}`)}
|
||||
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
|
||||
{Array.isArray(p[key]) && p[key].length > 0
|
||||
? p[key].map((v, i) => <div key={i}>- {v}</div>)
|
||||
: p[key]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
return null
|
||||
})}
|
||||
</>}
|
||||
)
|
||||
}
|
||||
return null
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</>
|
||||
}
|
||||
</div>
|
||||
</RbCard>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 17:57:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 17:57:15
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-13 11:49:52
|
||||
*/
|
||||
/**
|
||||
* User Memory Detail Types
|
||||
@@ -90,6 +90,7 @@ export interface ExtractedEntityNodeProperties {
|
||||
connect_strngth: string;
|
||||
importance_score: number;
|
||||
associative_memory: number;
|
||||
community_name?: string;
|
||||
}
|
||||
/**
|
||||
* Memory summary node
|
||||
@@ -246,4 +247,53 @@ export interface ForgetData {
|
||||
*/
|
||||
export interface GraphDetailRef {
|
||||
handleOpen: (vo: Node) => void
|
||||
}
|
||||
}
|
||||
// Community
|
||||
export type CommunityNodeType = 'Community' | 'ExtractedEntity';
|
||||
export type CommunityEdgeType = 'BELONGS_TO_COMMUNITY' | 'EXTRACTED_RELATIONSHIP';
|
||||
export type CommunityEntityType = "Person" | "Organization" | "ORG" | "Location" | "LOC" | "Event" | "Concept" | "Time" | "Position" | "WorkRole" | "System" | "Policy" | "HistoricalPeriod" | "HistoricalState" | "HistoricalEvent" | "EconomicFactor" | "Condition" | "Numeric" | "Work";
|
||||
// 社区节点
|
||||
export interface CommunityTypeNode {
|
||||
id: string;
|
||||
label: 'Community';
|
||||
properties: {
|
||||
community_id: string;
|
||||
end_user_id: string;
|
||||
member_count: number;
|
||||
updated_at: string;
|
||||
name: string;
|
||||
summary: string;
|
||||
core_entities: string[];
|
||||
member_entity_ids: string[];
|
||||
};
|
||||
}
|
||||
// 核心实体
|
||||
export interface ExtractedEntityTypeNode {
|
||||
id: string;
|
||||
label: 'ExtractedEntity';
|
||||
properties: {
|
||||
name: string;
|
||||
end_user_id: string;
|
||||
description: string;
|
||||
created_at: string;
|
||||
entity_type: CommunityEntityType;
|
||||
community_name: string;
|
||||
};
|
||||
}
|
||||
// 社区图谱连线
|
||||
export interface CommunityEdge {
|
||||
id: string;
|
||||
target: string;
|
||||
source: string;
|
||||
}
|
||||
export interface CommunityStatistics {
|
||||
total_nodes: number;
|
||||
total_edges: number;
|
||||
node_types: Record<CommunityNodeType, number>;
|
||||
edge_types: Record<CommunityEdgeType, number>;
|
||||
}
|
||||
export interface CommunityGraphData {
|
||||
nodes: (CommunityTypeNode | ExtractedEntityTypeNode)[];
|
||||
edges: CommunityEdge[];
|
||||
statistics: CommunityStatistics;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user