Merge branch 'develop' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/app-share-wxy

This commit is contained in:
wxy
2026-03-13 17:24:20 +08:00
100 changed files with 8956 additions and 1123 deletions

View File

@@ -1,10 +1,11 @@
import os
import asyncio
import json
import logging
from typing import Dict, Any, Optional
import redis.asyncio as redis
from redis.asyncio import ConnectionPool
from app.core.config import settings
# Set up the logger

View File

@@ -62,7 +62,7 @@ celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
# Timezone
timezone='Asia/Shanghai',
enable_utc=False,
@@ -70,43 +70,44 @@ celery_app.conf.update(
# Task tracking
task_track_started=True,
task_ignore_result=False,
# Timeout settings
task_time_limit=3600, # 60-minute hard time limit
task_soft_time_limit=3000, # 50-minute soft time limit
# Worker settings (per-worker settings are in the docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# Result expiration
result_expires=3600, # Keep results for 1 hour
# Task acknowledgement settings
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# Flower settings
worker_send_task_events=True,
task_send_sent_event=True,
# Task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
@@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab(
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
# Build the beat schedule configuration
beat_schedule_config = {
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",

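For reference, a minimal sketch (not part of this diff) of how the task_routes above behave at dispatch time; the app.core.celery import path is an assumption:

# Celery resolves the destination queue from task_routes by task name,
# so callers do not have to name a queue explicitly.
from app.core.celery import celery_app  # assumed module path

# Lands on the memory_tasks queue (threads worker) via task_routes:
celery_app.send_task('app.core.memory.agent.write_message', args=['hello'])

# An explicit queue= still overrides the routing table when needed:
celery_app.send_task(
    'app.core.rag.tasks.parse_document',
    kwargs={'document_id': 'doc-1'},  # hypothetical arguments
    queue='document_tasks',
)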
View File

@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
# Import task modules to register tasks
import app.tasks
__all__ = ['celery_app']

View File

@@ -16,6 +16,7 @@ from . import (
file_controller,
file_storage_controller,
home_page_controller,
i18n_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
@@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
__all__ = ["manager_router"]

View File

@@ -844,6 +844,15 @@ async def draft_run_compare(
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
other_id=str(current_user.id),
original_user_id=str(current_user.id) # Preserve the console user's original ID
)
payload.user_id = str(new_end_user.id)
# 2. Fetch the Agent configuration
from sqlalchemy import select
from app.models import AgentConfig
@@ -889,6 +898,8 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# Streaming response
if payload.stream:
async def event_generator():
@@ -900,7 +911,7 @@ async def draft_run_compare(
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
user_id=payload.user_id,
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
@@ -931,7 +942,7 @@ async def draft_run_compare(
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
user_id=payload.user_id,
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,

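For context, the get_or_create_end_user call above gives the console user a stable end-user identity for draft runs. A rough sketch of what such a repository method could look like, with assumed EndUser fields (app_id, other_id) — not the actual implementation:

from app.models import EndUser  # assumed import path

class EndUserRepository:
    def __init__(self, db):
        self.db = db

    def get_or_create_end_user(self, app_id, other_id, original_user_id):
        # Reuse the end user already bound to this app/console-user pair.
        end_user = (
            self.db.query(EndUser)
            .filter(EndUser.app_id == app_id, EndUser.other_id == other_id)
            .first()
        )
        if end_user is None:
            # First draft run for this user: create the mapping, keeping
            # the console user's original ID (here stored in other_id).
            end_user = EndUser(app_id=app_id, other_id=original_user_id)
            self.db.add(end_user)
            self.db.commit()
            self.db.refresh(end_user)
        return end_user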
View File

@@ -1,4 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import Callable
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.dependencies import get_current_user, oauth2_scheme
from app.models.user_model import User
from app.i18n.dependencies import get_translator
# Get the dedicated auth logger
auth_logger = get_auth_logger()
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
@router.post("/token", response_model=ApiResponse)
async def login_for_access_token(
form_data: TokenRequest,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
):
"""用户登录获取token"""
auth_logger.info(f"用户登录请求: {form_data.email}")
@@ -40,10 +43,10 @@ async def login_for_access_token(
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
if not invite_info.is_valid:
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
if invite_info.email != form_data.email:
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
try:
# Attempt to authenticate the user
@@ -69,7 +72,7 @@ async def login_for_access_token(
elif e.code == BizCode.PASSWORD_ERROR:
# User exists but the password is wrong
auth_logger.warning(f"Failed to accept invite: password verification error: {form_data.email}")
raise BusinessException("Failed to accept invite: password verification error", BizCode.LOGIN_FAILED)
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
else:
# Other authentication failures: re-raise
raise
@@ -82,7 +85,7 @@ async def login_for_access_token(
except BusinessException as e:
# Other authentication failures: re-raise
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
# Create tokens
access_token, access_token_id = security.create_access_token(subject=user.id)
@@ -110,14 +113,15 @@ async def login_for_access_token(
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="登录成功"
msg=t("auth.login.success")
)
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
):
"""刷新token"""
auth_logger.info("收到token刷新请求")
@@ -125,18 +129,18 @@ async def refresh_token(
# Validate the refresh token
userId = security.verify_token(refresh_request.refresh_token, "refresh")
if not userId:
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
# Check that the user exists
user = auth_service.get_user_by_id(db, userId)
if not user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# Check the refresh token blacklist
if settings.ENABLE_SINGLE_SESSION:
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
# Generate new tokens
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
@@ -167,7 +171,7 @@ async def refresh_token(
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="token刷新成功"
msg=t("auth.token.refresh_success")
)
@@ -175,14 +179,15 @@ async def refresh_token(
async def logout(
token: str = Depends(oauth2_scheme),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
):
"""登出当前用户加入token黑名单并清理会话"""
auth_logger.info(f"用户 {current_user.username} 请求登出")
token_id = security.get_token_id(token)
if not token_id:
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
# Add to the blacklist
await SessionService.blacklist_token(token_id)
@@ -192,5 +197,5 @@ async def logout(
await SessionService.clear_user_session(current_user.username)
auth_logger.info(f"用户 {current_user.username} 登出成功")
return success(msg="登出成功")
return success(msg=t("auth.logout.success"))

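For context, get_translator resolves the request's language and yields the t(key, **params) callable used above. A minimal sketch under assumed names (the real dependency lives in app.i18n.dependencies, and the service lookup API shown here is an assumption):

from typing import Callable
from fastapi import Header
from app.i18n.service import get_translation_service

def get_translator(
    x_language_type: str = Header(default="zh", alias="X-Language-Type")
) -> Callable:
    """Return a translation callable bound to the request's locale."""
    service = get_translation_service()

    def t(key: str, **params) -> str:
        # Hypothetical lookup: resolve the dotted key for the locale,
        # then interpolate named parameters into the template.
        return service.translate(key, locale=x_language_type, **params)

    return t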
View File

@@ -0,0 +1,833 @@
"""
I18n Management API Controller
This module provides management APIs for:
- Language management (list, get, add, update languages)
- Translation management (get, update, reload translations)
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import Callable, Optional
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.i18n.dependencies import get_translator
from app.i18n.service import get_translation_service
from app.models.user_model import User
from app.schemas.i18n_schema import (
LanguageInfo,
LanguageListResponse,
LanguageCreateRequest,
LanguageUpdateRequest,
TranslationResponse,
TranslationUpdateRequest,
MissingTranslationsResponse,
ReloadResponse
)
from app.schemas.response_schema import ApiResponse
api_logger = get_api_logger()
router = APIRouter(
prefix="/i18n",
tags=["I18n Management"],
)
# ============================================================================
# Language Management APIs
# ============================================================================
@router.get("/languages", response_model=ApiResponse)
def get_languages(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get list of all supported languages.
Returns:
List of language information including code, name, and status
"""
api_logger.info(f"Get languages request from user: {current_user.username}")
from app.core.config import settings
translation_service = get_translation_service()
# Get available locales from translation service
available_locales = translation_service.get_available_locales()
# Build language info list
languages = []
for locale in available_locales:
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
# Get native names
native_names = {
"zh": "中文(简体)",
"en": "English",
"ja": "日本語",
"ko": "한국어",
"fr": "Français",
"de": "Deutsch",
"es": "Español"
}
language_info = LanguageInfo(
code=locale,
name=f"{locale.upper()}",
native_name=native_names.get(locale, locale),
is_enabled=is_enabled,
is_default=is_default
)
languages.append(language_info)
response = LanguageListResponse(languages=languages)
api_logger.info(f"Returning {len(languages)} languages")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.get("/languages/{locale}", response_model=ApiResponse)
def get_language(
locale: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get information about a specific language.
Args:
locale: Language code (e.g., 'zh', 'en')
Returns:
Language information
"""
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
from app.core.config import settings
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Build language info
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
native_names = {
"zh": "中文(简体)",
"en": "English",
"ja": "日本語",
"ko": "한국어",
"fr": "Français",
"de": "Deutsch",
"es": "Español"
}
language_info = LanguageInfo(
code=locale,
name=f"{locale.upper()}",
native_name=native_names.get(locale, locale),
is_enabled=is_enabled,
is_default=is_default
)
api_logger.info(f"Returning language info for: {locale}")
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
@router.post("/languages", response_model=ApiResponse)
def add_language(
request: LanguageCreateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Add a new language (admin only).
Note: This endpoint validates the request but actual language addition
requires creating translation files in the locales directory.
Args:
request: Language creation request
Returns:
Success message
"""
api_logger.info(
f"Add language request: code={request.code}, admin={current_user.username}"
)
from app.core.config import settings
translation_service = get_translation_service()
# Check if language already exists
available_locales = translation_service.get_available_locales()
if request.code in available_locales:
api_logger.warning(f"Language already exists: {request.code}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=t("i18n.language.already_exists", locale=request.code)
)
# Note: Actual language addition requires creating translation files
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Language addition validated: {request.code}. "
"Translation files need to be created manually."
)
return success(
msg=t(
"i18n.language.add_instructions",
locale=request.code,
dir=settings.I18N_CORE_LOCALES_DIR
)
)
@router.put("/languages/{locale}", response_model=ApiResponse)
def update_language(
locale: str,
request: LanguageUpdateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Update language configuration (admin only).
Note: This endpoint validates the request but actual configuration
changes require updating environment variables or config files.
Args:
locale: Language code
request: Language update request
Returns:
Success message
"""
api_logger.info(
f"Update language request: locale={locale}, admin={current_user.username}"
)
translation_service = get_translation_service()
# Check if language exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Note: Actual configuration changes require updating settings
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Language update validated: {locale}. "
"Configuration changes require environment variable updates."
)
return success(msg=t("i18n.language.update_instructions", locale=locale))
# ============================================================================
# Translation Management APIs
# ============================================================================
@router.get("/translations", response_model=ApiResponse)
def get_all_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get all translations for all or specific locale.
Args:
locale: Optional locale filter
Returns:
All translations organized by locale and namespace
"""
api_logger.info(
f"Get all translations request: locale={locale}, user={current_user.username}"
)
translation_service = get_translation_service()
if locale:
# Get translations for specific locale
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
translations = {
locale: translation_service._cache.get(locale, {})
}
else:
# Get all translations
translations = translation_service._cache
response = TranslationResponse(translations=translations)
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.get("/translations/{locale}", response_model=ApiResponse)
def get_locale_translations(
locale: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get all translations for a specific locale.
Args:
locale: Language code
Returns:
All translations for the locale organized by namespace
"""
api_logger.info(
f"Get locale translations request: locale={locale}, user={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
translations = translation_service._cache.get(locale, {})
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
def get_namespace_translations(
locale: str,
namespace: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get translations for a specific namespace in a locale.
Args:
locale: Language code
namespace: Translation namespace (e.g., 'common', 'auth')
Returns:
Translations for the specified namespace
"""
api_logger.info(
f"Get namespace translations request: locale={locale}, "
f"namespace={namespace}, user={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Get namespace translations
locale_translations = translation_service._cache.get(locale, {})
namespace_translations = locale_translations.get(namespace, {})
if not namespace_translations:
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
)
api_logger.info(
f"Returning translations for namespace: {namespace} in locale: {locale}"
)
return success(
data={
"locale": locale,
"namespace": namespace,
"translations": namespace_translations
},
msg=t("common.success.retrieved")
)
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
def update_translation(
locale: str,
key: str,
request: TranslationUpdateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Update a single translation (admin only).
Note: This endpoint validates the request but actual translation updates
require modifying translation files in the locales directory.
Args:
locale: Language code
key: Translation key (format: "namespace.key.subkey")
request: Translation update request
Returns:
Success message
"""
api_logger.info(
f"Update translation request: locale={locale}, key={key}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Validate key format
if "." not in key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=t("i18n.translation.invalid_key_format", key=key)
)
# Note: Actual translation updates require modifying JSON files
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Translation update validated: {locale}/{key}. "
"Translation files need to be updated manually."
)
return success(
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
)
@router.get("/translations/missing", response_model=ApiResponse)
def get_missing_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get list of missing translations.
Compares translations across locales to find missing keys.
Args:
locale: Optional locale to check (defaults to checking all non-default locales)
Returns:
List of missing translation keys
"""
api_logger.info(
f"Get missing translations request: locale={locale}, user={current_user.username}"
)
from app.core.config import settings
translation_service = get_translation_service()
default_locale = settings.I18N_DEFAULT_LANGUAGE
available_locales = translation_service.get_available_locales()
# Get default locale translations as reference
default_translations = translation_service._cache.get(default_locale, {})
# Collect all keys from default locale
def collect_keys(data, prefix=""):
keys = []
for key, value in data.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, dict):
keys.extend(collect_keys(value, full_key))
else:
keys.append(full_key)
return keys
default_keys = set()
for namespace, translations in default_translations.items():
namespace_keys = collect_keys(translations, namespace)
default_keys.update(namespace_keys)
# Find missing keys in target locale(s)
missing_by_locale = {}
target_locales = [locale] if locale else [
loc for loc in available_locales if loc != default_locale
]
for target_locale in target_locales:
if target_locale not in available_locales:
continue
target_translations = translation_service._cache.get(target_locale, {})
target_keys = set()
for namespace, translations in target_translations.items():
namespace_keys = collect_keys(translations, namespace)
target_keys.update(namespace_keys)
missing_keys = default_keys - target_keys
if missing_keys:
missing_by_locale[target_locale] = sorted(list(missing_keys))
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
total_missing = sum(len(keys) for keys in missing_by_locale.values())
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.post("/reload", response_model=ApiResponse)
def reload_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Trigger hot reload of translation files (admin only).
Args:
locale: Optional locale to reload (defaults to reloading all locales)
Returns:
Reload status and statistics
"""
api_logger.info(
f"Reload translations request: locale={locale or 'all'}, "
f"admin={current_user.username}"
)
from app.core.config import settings
if not settings.I18N_ENABLE_HOT_RELOAD:
api_logger.warning("Hot reload is disabled in configuration")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=t("i18n.reload.disabled")
)
translation_service = get_translation_service()
try:
# Reload translations
translation_service.reload(locale)
# Get statistics
available_locales = translation_service.get_available_locales()
reloaded_locales = [locale] if locale else available_locales
response = ReloadResponse(
success=True,
reloaded_locales=reloaded_locales,
total_locales=len(available_locales)
)
api_logger.info(
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
)
return success(data=response.dict(), msg=t("i18n.reload.success"))
except Exception as e:
api_logger.error(f"Failed to reload translations: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=t("i18n.reload.failed", error=str(e))
)
# ============================================================================
# Performance Monitoring APIs
# ============================================================================
@router.get("/metrics", response_model=ApiResponse)
def get_metrics(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get i18n performance metrics (admin only).
Returns:
Performance metrics including:
- Request counts
- Missing translations
- Timing statistics
- Locale usage
- Error counts
"""
api_logger.info(f"Get metrics request: admin={current_user.username}")
translation_service = get_translation_service()
metrics = translation_service.get_metrics_summary()
api_logger.info("Returning i18n metrics")
return success(data=metrics, msg=t("common.success.retrieved"))
@router.get("/metrics/cache", response_model=ApiResponse)
def get_cache_stats(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get cache statistics (admin only).
Returns:
Cache statistics including:
- Hit/miss rates
- LRU cache performance
- Loaded locales
- Memory usage
"""
api_logger.info(f"Get cache stats request: admin={current_user.username}")
translation_service = get_translation_service()
cache_stats = translation_service.get_cache_stats()
memory_usage = translation_service.get_memory_usage()
data = {
"cache": cache_stats,
"memory": memory_usage
}
api_logger.info("Returning cache statistics")
return success(data=data, msg=t("common.success.retrieved"))
@router.get("/metrics/prometheus")
def get_prometheus_metrics(
current_user: User = Depends(get_current_superuser)
):
"""
Get metrics in Prometheus format (admin only).
Returns:
Prometheus-formatted metrics as plain text
"""
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
from app.i18n.metrics import get_metrics
metrics = get_metrics()
prometheus_output = metrics.export_prometheus()
from fastapi.responses import PlainTextResponse
return PlainTextResponse(content=prometheus_output)
@router.post("/metrics/reset", response_model=ApiResponse)
def reset_metrics(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Reset all metrics (admin only).
Returns:
Success message
"""
api_logger.info(f"Reset metrics request: admin={current_user.username}")
from app.i18n.metrics import get_metrics
metrics = get_metrics()
metrics.reset()
translation_service = get_translation_service()
translation_service.cache.reset_stats()
api_logger.info("Metrics reset completed")
return success(msg=t("i18n.metrics.reset_success"))
# ============================================================================
# Missing Translation Logging and Reporting APIs
# ============================================================================
@router.get("/logs/missing", response_model=ApiResponse)
def get_missing_translation_logs(
locale: Optional[str] = None,
limit: Optional[int] = 100,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get missing translation logs (admin only).
Returns logged missing translations with context information.
Args:
locale: Optional locale filter
limit: Maximum number of entries to return (default: 100)
Returns:
Missing translation logs with context
"""
api_logger.info(
f"Get missing translation logs request: locale={locale}, "
f"limit={limit}, admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Get missing translations
missing_translations = translation_logger.get_missing_translations(locale)
# Get missing with context
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
# Get statistics
statistics = translation_logger.get_statistics()
data = {
"missing_translations": missing_translations,
"recent_context": missing_with_context,
"statistics": statistics
}
api_logger.info(
f"Returning {statistics['total_missing']} missing translations"
)
return success(data=data, msg=t("common.success.retrieved"))
@router.get("/logs/missing/report", response_model=ApiResponse)
def generate_missing_translation_report(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Generate a comprehensive missing translation report (admin only).
Args:
locale: Optional locale filter
Returns:
Comprehensive report with missing translations and statistics
"""
api_logger.info(
f"Generate missing translation report request: locale={locale}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Generate report
report = translation_logger.generate_report(locale)
api_logger.info(
f"Generated report with {report['total_missing']} missing translations"
)
return success(data=report, msg=t("common.success.retrieved"))
@router.post("/logs/missing/export", response_model=ApiResponse)
def export_missing_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Export missing translations to JSON file (admin only).
Args:
locale: Optional locale filter
Returns:
Export status and file path
"""
api_logger.info(
f"Export missing translations request: locale={locale}, "
f"admin={current_user.username}"
)
from datetime import datetime
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
locale_suffix = f"_{locale}" if locale else "_all"
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
# Export to file
translation_logger.export_to_json(output_file)
api_logger.info(f"Missing translations exported to: {output_file}")
return success(
data={"file_path": output_file},
msg=t("i18n.logs.export_success", file=output_file)
)
@router.delete("/logs/missing", response_model=ApiResponse)
def clear_missing_translation_logs(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Clear missing translation logs (admin only).
Args:
locale: Optional locale to clear (clears all if not specified)
Returns:
Success message
"""
api_logger.info(
f"Clear missing translation logs request: locale={locale or 'all'}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Clear logs
translation_logger.clear(locale)
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
return success(msg=t("i18n.logs.clear_success"))

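The missing-translation endpoint above works by flattening nested translation dicts into dotted keys and diffing them against the default locale. A self-contained sketch of that core step:

def collect_keys(data, prefix=""):
    # Flatten {"auth": {"login": {"success": "..."}}} into ["auth.login.success"].
    keys = []
    for key, value in data.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            keys.extend(collect_keys(value, full_key))
        else:
            keys.append(full_key)
    return keys

default = {"auth": {"login": {"success": "ok"}, "logout": {"success": "bye"}}}
target = {"auth": {"login": {"success": "好"}}}

missing = set(collect_keys(default)) - set(collect_keys(target))
print(sorted(missing))  # ['auth.logout.success']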
View File

@@ -1,3 +1,19 @@
"""
Memory Reflection Controller
This module provides REST API endpoints for managing memory reflection configurations
and operations. It handles reflection engine setup, configuration management, and
execution of self-reflection processes across memory systems.
Key Features:
- Reflection configuration management (save, retrieve, update)
- Workspace-wide reflection execution across multiple applications
- Individual configuration-based reflection runs
- Multi-language support for reflection outputs
- Integration with Neo4j memory storage and LLM models
- Comprehensive error handling and logging
"""
import asyncio
import time
import uuid
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# Load environment variables for configuration
load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter(
prefix="/memory",
tags=["Memory"],
@@ -43,7 +63,38 @@ async def save_reflection_config(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""Save reflection configuration to data_comfig table"""
"""
Save reflection configuration to memory config table
Persists reflection engine configuration settings to the data_config table,
including reflection parameters, model settings, and evaluation criteria.
Validates configuration parameters and ensures data consistency.
Args:
request: Memory reflection configuration data including:
- config_id: Configuration identifier to update
- reflection_enabled: Whether reflection is enabled
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model for reflection operations
- memory_verify: Enable memory verification checks
- quality_assessment: Enable quality assessment evaluation
current_user: Authenticated user saving the configuration
db: Database session for data operations
Returns:
dict: Success response with saved reflection configuration data
Raises:
HTTPException 400: If config_id is missing or parameters are invalid
HTTPException 500: If configuration save operation fails
Database Operations:
- Updates memory_config table with reflection settings
- Commits transaction and refreshes entity
- Maintains configuration consistency
"""
try:
config_id = request.config_id
config_id = resolve_config_id(config_id, db)
@@ -54,6 +105,7 @@ async def save_reflection_config(
)
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
# Update reflection configuration in database
memory_config = MemoryConfigRepository.update_reflection_config(
db,
config_id=config_id,
@@ -66,6 +118,7 @@ async def save_reflection_config(
quality_assessment=request.quality_assessment
)
# Commit transaction and refresh entity
db.commit()
db.refresh(memory_config)
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""启动工作空间中所有匹配应用的反思功能"""
"""
Start reflection functionality for all matching applications in workspace
Initiates reflection processes across all applications within the user's current
workspace that have valid memory configurations. Processes each application's
configurations and associated end users, executing reflection operations
with proper error isolation and transaction management.
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
that reflection failures for individual users don't affect other operations.
Args:
current_user: Authenticated user initiating workspace reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection results for all processed applications:
- app_id: Application identifier
- config_id: Memory configuration identifier
- end_user_id: End user identifier
- reflection_result: Individual reflection operation result
Processing Logic:
1. Retrieve all applications in the current workspace
2. Filter applications with valid memory configurations
3. For each configuration, find matching releases
4. Execute reflection for each end user with isolated transactions
5. Aggregate results with error handling per user
Error Handling:
- Individual user reflection failures are isolated
- Failed operations are logged and included in results
- Database transactions are isolated per user to prevent cascading failures
- Comprehensive error reporting for debugging
Raises:
HTTPException 500: If workspace reflection initialization fails
Performance Notes:
- Uses independent database sessions for each user operation
- Prevents transaction failures from affecting other users
- Comprehensive logging for operation tracking
"""
workspace_id = current_user.current_workspace_id
try:
api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
# Use independent database session to get workspace app details, avoiding transaction failures
from app.db import get_db_context
with get_db_context() as query_db:
service = WorkspaceAppService(query_db)
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
reflection_results = []
# Process each application in the workspace
for data in result['apps_detailed_info']:
# Skip applications without configurations
if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
continue
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
memory_configs = data['memory_configs']
end_users = data['end_users']
# Execute reflection for each configuration and user combination
for config in memory_configs:
config_id_str = str(config['config_id'])
# Find all releases matching this configuration
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue
# Execute reflection for each user - using independent database sessions
for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
# Create independent database session for each user to avoid transaction failure impact
with get_db_context() as user_db:
try:
reflection_service = MemoryReflectionService(user_db)
@@ -184,14 +280,51 @@ async def start_reflection_configs(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询memory_config表中的反思配置信息"""
"""
Query reflection configuration information by config_id
Retrieves detailed reflection configuration settings from the memory_config
table for a specific configuration ID. Provides comprehensive reflection
parameters including model settings, evaluation criteria, and operational flags.
Args:
config_id: Configuration identifier (UUID or integer) to query
current_user: Authenticated user making the request
db: Database session for data operations
Returns:
dict: Success response with detailed reflection configuration:
- config_id: Resolved configuration identifier
- reflection_enabled: Whether reflection is enabled for this config
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection operations (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model identifier for reflection
- memory_verify: Memory verification flag
- quality_assessment: Quality assessment flag
Database Operations:
- Queries memory_config table by resolved config_id
- Retrieves all reflection-related configuration fields
- Resolves configuration ID for consistent formatting
Raises:
HTTPException 404: If configuration with specified ID is not found
HTTPException 500: If configuration query operation fails
ID Resolution:
- Supports both UUID and integer config_id formats
- Automatically resolves to appropriate internal format
- Maintains consistency across different ID representations
"""
config_id = resolve_config_id(config_id, db)
try:
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db)
# Build response data with comprehensive configuration details
reflection_config = {
"config_id": memory_config_id,
"reflection_enabled": result.enable_self_reflexion,
@@ -204,10 +337,12 @@ async def start_reflection_configs(
}
api_logger.info(f"成功查询反思配置config_id: {config_id}")
return success(data=reflection_config, msg="反思配置查询成功")
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
return success(data=reflection_config, msg="Reflection configuration query successful")
except HTTPException:
# Re-raise HTTP exceptions without modification
raise
except Exception as e:
api_logger.error(f"查询反思配置失败: {str(e)}")
@@ -223,13 +358,66 @@ async def reflection_run(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""Activate the reflection function for all matching applications in the workspace"""
# Use centralized language validation
"""
Execute reflection engine with specified configuration
Runs the reflection engine using configuration parameters from the database.
Validates model availability, sets up the reflection engine with proper
configuration, and executes the reflection process with multi-language support.
This endpoint provides a test run capability for reflection configurations,
allowing users to validate their reflection settings and see results before
deploying to production environments.
Args:
config_id: Configuration identifier (UUID or integer) for reflection settings
language_type: Language preference header for output localization (optional)
current_user: Authenticated user executing the reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection execution results including:
- baseline: Reflection strategy used
- source_data: Input data processed
- memory_verifies: Memory verification results (if enabled)
- quality_assessments: Quality assessment results (if enabled)
- reflexion_data: Generated reflection insights and solutions
Configuration Validation:
- Verifies configuration exists in database
- Validates LLM model availability
- Falls back to default model if specified model is unavailable
- Ensures all required parameters are properly set
Reflection Engine Setup:
- Creates ReflectionConfig with database parameters
- Initializes Neo4j connector for memory access
- Sets up ReflectionEngine with validated model
- Configures language preferences for output
Error Handling:
- Model validation with fallback to default
- Configuration validation and error reporting
- Comprehensive logging for debugging
- Graceful handling of missing configurations
Raises:
HTTPException 404: If configuration is not found
HTTPException 500: If reflection execution fails
Performance Notes:
- Direct database query for configuration retrieval
- Model validation to prevent runtime failures
- Efficient reflection engine initialization
- Language-aware output processing
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db)
# Query reflection configuration using MemoryConfigRepository
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result:
raise HTTPException(
@@ -239,7 +427,7 @@ async def reflection_run(
api_logger.info(f"成功查询反思配置config_id: {config_id}")
# Validate model ID existence
model_id = result.reflection_model_id
if model_id:
try:
@@ -250,6 +438,7 @@ async def reflection_run(
# Can be set to None so the reflection engine uses the default model
model_id = None
# Create reflection configuration with database parameters
config = ReflectionConfig(
enabled=result.enable_self_reflexion,
iteration_period=result.iteration_period,
@@ -262,11 +451,13 @@ async def reflection_run(
model_id=model_id,
language_type=language_type
)
# Initialize Neo4j connector and reflection engine
connector = Neo4jConnector()
engine = ReflectionEngine(
config=config,
neo4j_connector=connector,
llm_client=model_id # Pass validated model_id
)
result = await engine.reflection_run()

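The per-user session isolation the docstrings describe comes down to opening a fresh database session per unit of work, so one failure cannot poison the rest. A minimal sketch of the pattern; end_users and the run method name are assumptions, while get_db_context and MemoryReflectionService appear in the diff:

from app.db import get_db_context

results = []
for user in end_users:
    # Independent session per user: a rollback here affects only this user.
    with get_db_context() as user_db:
        try:
            service = MemoryReflectionService(user_db)
            results.append({
                "end_user_id": user["id"],
                "reflection_result": service.run(user["id"]),  # assumed method
            })
        except Exception as exc:
            # Record the failure and keep going; other users are unaffected.
            results.append({"end_user_id": user["id"], "error": str(exc)})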
View File

@@ -1,3 +1,18 @@
"""
Memory Short Term Controller
This module provides REST API endpoints for managing short-term and long-term memory
data retrieval and analysis. It handles memory system statistics, data aggregation,
and provides comprehensive memory insights for end users.
Key Features:
- Short-term memory data retrieval and statistics
- Long-term memory data aggregation
- Entity count integration
- Multi-language response support
- Memory system analytics and reporting
"""
from typing import Optional
from dotenv import load_dotenv
@@ -13,9 +28,13 @@ from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity
# Load environment variables for configuration
load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
@@ -27,24 +46,73 @@ async def short_term_configs(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# Use centralized language validation
"""
Retrieve comprehensive short-term and long-term memory statistics
Provides a comprehensive overview of memory system data for a specific end user,
including short-term memory entries, long-term memory aggregations, entity counts,
and retrieval statistics. Supports multi-language responses based on request headers.
This endpoint serves as a central dashboard for memory system analytics, combining
data from multiple memory subsystems to provide a holistic view of user memory state.
Args:
end_user_id: Unique identifier for the end user whose memory data to retrieve
language_type: Language preference header for response localization (optional)
current_user: Authenticated user making the request (injected by dependency)
db: Database session for data operations (injected by dependency)
Returns:
dict: Success response containing comprehensive memory statistics:
- short_term: List of short-term memory entries with detailed data
- long_term: List of long-term memory aggregations and summaries
- entity: Count of entities associated with the end user
- retrieval_number: Total count of short-term memory retrievals
- long_term_number: Total count of long-term memory entries
Response Structure:
{
"code": 200,
"msg": "Short-term memory system data retrieved successfully",
"data": {
"short_term": [...], # Short-term memory entries
"long_term": [...], # Long-term memory data
"entity": 42, # Entity count
"retrieval_number": 156, # Short-term retrieval count
"long_term_number": 23 # Long-term memory count
}
}
Raises:
HTTPException: If end_user_id is invalid or data retrieval fails
Performance Notes:
- Combines multiple service calls for comprehensive data
- Entity search is performed asynchronously for better performance
- Response time depends on memory data volume for the specified user
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type)
# Retrieve short-term memory data and statistics
short_term = ShortService(end_user_id, db)
short_result = short_term.get_short_databasets() # Get short-term memory entries
short_count = short_term.get_short_count() # Get short-term retrieval count
# Retrieve long-term memory data and aggregations
long_term = LongService(end_user_id, db)
long_result = long_term.get_long_databasets() # Get long-term memory entries
# Get entity count for the specified end user
entity_result = await search_entity(end_user_id)
# Compile comprehensive memory statistics response
result = {
'short_term': short_result, # Short-term memory entries
'long_term': long_result, # Long-term memory data
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
"retrieval_number": short_count, # Short-term retrieval statistics
"long_term_number": len(long_result) # Long-term memory entry count
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import uuid
from typing import Callable
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -19,6 +20,7 @@ from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.security import verify_password
from app.i18n.dependencies import get_translator
# Get the dedicated API logger
api_logger = get_api_logger()
@@ -33,7 +35,8 @@ router = APIRouter(
def create_superuser(
user: user_schema.UserCreate,
db: Session = Depends(get_db),
current_superuser: User = Depends(get_current_superuser)
current_superuser: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
):
"""创建超级管理员(仅超级管理员可访问)"""
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
@@ -42,7 +45,7 @@ def create_superuser(
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="超级管理员创建成功")
return success(data=result_schema, msg=t("users.create.superuser_success"))
@router.delete("/{user_id}", response_model=ApiResponse)
@@ -50,6 +53,7 @@ def delete_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""停用用户(软删除)"""
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -57,13 +61,14 @@ def delete_user(
db=db, user_id_to_deactivate=user_id, current_user=current_user
)
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
return success(msg="用户停用成功")
return success(msg=t("users.delete.deactivate_success"))
@router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""激活用户"""
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -74,13 +79,14 @@ def activate_user(
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户激活成功")
return success(data=result_schema, msg=t("users.activate.success"))
@router.get("", response_model=ApiResponse)
def get_current_user_info(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""获取当前用户信息"""
api_logger.info(f"当前用户信息请求: {current_user.username}")
@@ -105,7 +111,7 @@ def get_current_user_info(
break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
return success(data=result_schema, msg="用户信息获取成功")
return success(data=result_schema, msg=t("users.info.get_success"))
@router.get("/superusers", response_model=ApiResponse)
@@ -113,6 +119,7 @@ def get_tenant_superusers(
include_inactive: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
):
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
@@ -125,7 +132,7 @@ def get_tenant_superusers(
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
return success(data=superusers_schema, msg="租户超管列表获取成功")
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
@@ -134,6 +141,7 @@ def get_user_info_by_id(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""根据用户ID获取用户信息"""
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -144,7 +152,7 @@ def get_user_info_by_id(
api_logger.info(f"用户信息获取成功: {result.username}")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户信息获取成功")
return success(data=result_schema, msg=t("users.info.get_success"))
@router.put("/change-password", response_model=ApiResponse)
@@ -152,6 +160,7 @@ async def change_password(
request: ChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""修改当前用户密码"""
api_logger.info(f"用户密码修改请求: {current_user.username}")
@@ -164,7 +173,7 @@ async def change_password(
current_user=current_user
)
api_logger.info(f"用户密码修改成功: {current_user.username}")
return success(msg="密码修改成功")
return success(msg=t("auth.password.change_success"))
@router.put("/admin/change-password", response_model=ApiResponse)
@@ -172,6 +181,7 @@ async def admin_change_password(
request: AdminChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
):
"""超级管理员修改指定用户的密码"""
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
@@ -186,16 +196,17 @@ async def admin_change_password(
# Build the response based on whether a random password was generated
if request.new_password:
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
return success(msg="密码修改成功")
return success(msg=t("auth.password.change_success"))
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg="密码重置成功")
return success(data=generated_password, msg=t("auth.password.reset_success"))
@router.post("/verify_pwd", response_model=ApiResponse)
def verify_pwd(
request: VerifyPasswordRequest,
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""验证当前用户密码"""
api_logger.info(f"用户验证密码请求: {current_user.username}")
@@ -203,8 +214,8 @@ def verify_pwd(
is_valid = verify_password(request.password, current_user.hashed_password)
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
if not is_valid:
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
return success(data={"valid": is_valid}, msg="验证完成")
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
@router.post("/send-email-code", response_model=ApiResponse)
@@ -212,6 +223,7 @@ async def send_email_code(
request: SendEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""发送邮箱验证码"""
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
@@ -219,7 +231,7 @@ async def send_email_code(
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
return success(msg="验证码已发送到您的邮箱,请查收")
return success(msg=t("users.email.code_sent"))
@router.put("/change-email", response_model=ApiResponse)
@@ -227,6 +239,7 @@ async def change_email(
request: VerifyEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""验证验证码并修改邮箱"""
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
@@ -239,4 +252,51 @@ async def change_email(
)
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
return success(msg="邮箱修改成功")
return success(msg=t("users.email.change_success"))
@router.get("/me/language", response_model=ApiResponse)
def get_current_user_language(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""获取当前用户的语言偏好"""
api_logger.info(f"获取用户语言偏好: {current_user.username}")
language = user_service.get_user_language_preference(
db=db,
user_id=current_user.id,
current_user=current_user
)
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
return success(
data=user_schema.LanguagePreferenceResponse(language=language),
msg=t("users.language.get_success")
)
@router.put("/me/language", response_model=ApiResponse)
def update_current_user_language(
request: user_schema.LanguagePreferenceRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""设置当前用户的语言偏好"""
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
updated_user = user_service.update_user_language_preference(
db=db,
user_id=current_user.id,
language=request.language,
current_user=current_user
)
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
return success(
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
msg=t("users.language.update_success")
)
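The `t: Callable = Depends(get_translator)` parameter threaded through these endpoints resolves a translation function bound to the request's language. The real implementation lives in `app.i18n.dependencies` and is not shown in this diff; a minimal sketch of how such a dependency could work, assuming a simple in-memory catalog keyed by language and dot-separated message keys, looks like this:

from typing import Callable
from fastapi import Header

# Hypothetical catalog; the real loader reads JSON files from
# I18N_CORE_LOCALES_DIR and caches them (see the Settings block below).
translations = {
    "zh": {"auth.password.change_success": "密码修改成功"},
    "en": {"auth.password.change_success": "Password changed successfully"},
}

def get_translator(
    accept_language: str = Header(default="zh", alias="Accept-Language"),
) -> Callable[[str], str]:
    lang = accept_language if accept_language in translations else "zh"

    def t(key: str) -> str:
        # Fall back to the key itself when a translation is missing
        return translations[lang].get(key, key)

    return t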

View File

@@ -14,6 +14,12 @@ from app.dependencies import (
get_current_user,
workspace_access_guard,
)
from app.i18n.dependencies import get_current_language, get_translator
from app.i18n.serializers import (
WorkspaceSerializer,
WorkspaceMemberSerializer,
WorkspaceInviteSerializer
)
from app.models.tenant_model import Tenants
from app.models.user_model import User
from app.models.workspace_model import InviteStatus
@@ -65,7 +71,9 @@ def get_workspaces(
include_current: bool = Query(True, description="是否包含当前工作空间"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
current_tenant: Tenants = Depends(get_current_tenant)
current_tenant: Tenants = Depends(get_current_tenant),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取当前租户下用户参与的所有工作空间
@@ -88,8 +96,13 @@ def get_workspaces(
)
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
return success(data=workspaces_schema, msg="工作空间列表获取成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
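The serializer classes imported from `app.i18n.serializers` are not defined in this diff. A plausible minimal shape, assuming each serializer attaches localized display strings next to raw field values (the `workspace_type` field here is purely illustrative), might be:

class WorkspaceSerializer:
    # Hypothetical label map; the real class presumably derives this from locale files
    _labels = {
        "zh": {"default": "默认工作空间"},
        "en": {"default": "Default workspace"},
    }

    def serialize(self, data: dict, language: str) -> dict:
        data = dict(data)
        # Add a localized display field alongside the raw value
        if "workspace_type" in data:
            data["workspace_type_label"] = self._labels.get(language, {}).get(
                data["workspace_type"], data["workspace_type"]
            )
        return data

    def serialize_list(self, items: list, language: str) -> list:
        return [self.serialize(item, language) for item in items]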
@router.post("", response_model=ApiResponse)
@@ -98,6 +111,8 @@ def create_workspace(
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""创建新的工作空间"""
from app.core.language_utils import get_language_from_header
@@ -118,8 +133,13 @@ def create_workspace(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}"
)
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.created"))
@router.put("", response_model=ApiResponse)
@cur_workspace_access_guard()
@@ -127,6 +147,8 @@ def update_workspace(
workspace: WorkspaceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""更新工作空间"""
workspace_id = current_user.current_workspace_id
@@ -139,14 +161,21 @@ def update_workspace(
user=current_user,
)
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间更新成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.updated"))
@router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_cur_workspace_members(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间成员列表(关系序列化)"""
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
user=current_user,
)
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
# 转换为表格项并使用序列化器添加国际化字段
table_items = _convert_members_to_table_items(members)
return success(data=table_items, msg="工作空间成员列表获取成功")
serializer = WorkspaceMemberSerializer()
members_data = [item.model_dump() for item in table_items]
members_i18n = serializer.serialize_list(members_data, language)
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
@router.put("/members", response_model=ApiResponse)
@@ -168,6 +203,7 @@ def update_workspace_members(
updates: List[WorkspaceMemberUpdate],
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
@@ -178,7 +214,7 @@ def update_workspace_members(
user=current_user,
)
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
return success(msg="成员角色更新成功")
return success(msg=t("workspace.members.role_updated"))
@router.delete("/members/{member_id}", response_model=ApiResponse)
@@ -187,6 +223,7 @@ def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
@@ -198,7 +235,7 @@ def delete_workspace_member(
user=current_user,
)
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
return success(msg="成员删除成功")
return success(msg=t("workspace.members.deleted"))
# 创建空间协作邀请
@@ -208,6 +245,8 @@ def create_workspace_invite(
invite_data: WorkspaceInviteCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""创建工作空间邀请"""
workspace_id = current_user.current_workspace_id
@@ -220,7 +259,12 @@ def create_workspace_invite(
user=current_user
)
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
return success(data=result, msg="邀请创建成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.created"))
@router.get("/invites", response_model=ApiResponse)
@@ -232,6 +276,8 @@ def get_workspace_invites(
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间邀请列表"""
workspace_id = current_user.current_workspace_id
@@ -246,18 +292,30 @@ def get_workspace_invites(
offset=offset
)
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
return success(data=invites, msg="邀请列表获取成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
invites_i18n = serializer.serialize_list(invites, language)
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
token: str,
db: Session = Depends(get_db),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间邀请用户信息(无需认证)"""
result = workspace_service.validate_invite_token(db=db, token=token)
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
return success(data=result, msg="邀请验证成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.validated"))
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
invite_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""撤销工作空间邀请"""
workspace_id = current_user.current_workspace_id
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
user=current_user
)
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
return success(data=result, msg="邀请撤销成功")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
# ==================== 公开邀请接口(无需认证) ====================
@@ -302,6 +367,7 @@ def switch_workspace(
workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""切换工作空间"""
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
@@ -312,7 +378,7 @@ def switch_workspace(
user=current_user,
)
api_logger.info(f"成功切换工作空间为 {workspace_id}")
return success(msg="工作空间切换成功")
return success(msg=t("workspace.switched"))
@router.get("/storage", response_model=ApiResponse)
@@ -320,6 +386,7 @@ def switch_workspace(
def get_workspace_storage_type(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""获取当前工作空间的存储类型"""
workspace_id = current_user.current_workspace_id
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
user=current_user
)
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
@router.get("/workspace_models", response_model=ApiResponse)
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
def workspace_models_configs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
@@ -354,14 +423,14 @@ def workspace_models_configs(
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作空间不存在或无权访问"
detail=t("workspace.not_found")
)
api_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
@router.put("/workspace_models", response_model=ApiResponse)
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
models_update: WorkspaceModelsUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""更新当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
f"成功更新工作空间 {workspace_id} 的模型配置: "
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
)
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))

View File

@@ -162,6 +162,44 @@ class Settings:
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# ========================================================================
# Internationalization (i18n) Configuration
# ========================================================================
# Default language for API responses
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
# Supported languages (comma-separated)
I18N_SUPPORTED_LANGUAGES: list[str] = [
lang.strip()
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
if lang.strip()
]
# Core locales directory (community edition)
# Use absolute path to work from any working directory
I18N_CORE_LOCALES_DIR: str = os.getenv(
"I18N_CORE_LOCALES_DIR",
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
)
# Premium locales directory (enterprise edition, optional)
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
# Enable translation cache
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
# LRU cache size for hot translations
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
# Enable hot reload of translation files
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
# Fallback language when translation is missing
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
# Log missing translations
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
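As a quick illustration of how these settings compose, the sketch below (not code from this commit) normalizes a requested language against the configured list, falling back to `I18N_FALLBACK_LANGUAGE` when the request is missing or unsupported:

import os
from typing import Optional

supported = [
    lang.strip()
    for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
    if lang.strip()
]
fallback = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")

def resolve_language(requested: Optional[str]) -> str:
    # Serve the requested language only when it is configured; otherwise fall back
    if requested and requested in supported:
        return requested
    return fallback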

View File

@@ -2,15 +2,37 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState:
"""开始节点 - 提取内容并保持状态信息"""
"""
Start node - Extract content and maintain state information
Extracts the content from the first message in the state and returns it
as the data field while preserving all other state information.
Args:
state: ReadState containing messages and other state data
Returns:
ReadState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
# Return content and maintain all state information
return {"data": content}
def content_input_write(state: WriteState) -> WriteState:
"""开始节点 - 提取内容并保持状态信息"""
"""
Start node - Extract content and maintain state information for write operations
Extracts the content from the first message in the state for write operations.
Args:
state: WriteState containing messages and other state data
Returns:
WriteState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}
# Return content and maintain all state information
return {"data": content}

View File

@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
"""
Problem processing node service class
Handles problem decomposition and extension operations using LLM services.
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
"""
Problem decomposition node
Breaks down complex user queries into smaller, more manageable sub-problems.
Uses LLM to analyze the input and generate structured problem decomposition
with question types and reasoning.
Args:
state: ReadState containing user input and configuration
Returns:
ReadState: Updated state with problem decomposition results
"""
# Get data from state
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
# Add more detailed logging
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应
# Validate structured response
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
split_result = json.dumps([], ensure_ascii=False)
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
exc_info=True
)
# 提供更详细的错误信息
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果
# Create default empty result
result = {
"context": json.dumps([], ensure_ascii=False),
"original": content,
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
}
}
# 返回更新后的状态,包含spit_context字段
# Return updated state including spit_context field
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
"""
Problem extension node
Extends the decomposed problems from Split_The_Problem node by generating
additional related questions and organizing them by original question.
Uses LLM to create comprehensive question extensions for better memory retrieval.
Args:
state: ReadState containing decomposed problems and configuration
Returns:
ReadState: Updated state with extended problem results
"""
# Get original data and decomposition results
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# 验证结构化响应
# Validate structured response
if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {}
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
exc_info=True
)
# 提供更详细的错误信息
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),

View File

@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
@@ -48,6 +60,19 @@ async def rag_config(state):
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results.
Args:
state: Current state containing configuration
question: Question to search for
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
async def llm_infomation(state: ReadState) -> ReadState:
"""
Get LLM configuration information from state
Retrieves model configuration details including model ID and tenant ID
from the memory configuration in the current state.
Args:
state: ReadState containing memory configuration
Returns:
model_schema.ModelConfig: Validated model configuration (Pydantic model)
"""
memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id
# 使用现有的 memory_config 而不是重新查询数据库
# 或者使用线程安全的数据库访问
# Use existing memory_config instead of re-querying database
# or use thread-safe database access
with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
async def clean_databases(data) -> str:
"""
简化的数据库搜索结果清理函数
Simplified database search result cleaning function
Processes and cleans search results from various sources including
reranked results and time-based search results. Extracts text content
from structured data and returns as formatted string.
Args:
data: 搜索结果数据
data: Search result data (can be string, dict, or other types)
Returns:
清理后的内容字符串
str: Cleaned content string
"""
try:
# 解析JSON字符串
# Parse JSON string
if isinstance(data, str):
try:
data = json.loads(data)
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
if not isinstance(data, dict):
return str(data)
# 获取结果数据
# Get result data
# with open("搜索结果.json","w",encoding='utf-8') as f:
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
results = data.get('results', data)
if not isinstance(results, dict):
return str(results)
# 收集所有内容
# Collect all content
content_list = []
# 处理重排序结果
# Process reranked results
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
# 处理时间搜索结果
# Process time search results
time_search = results.get('time_search', {})
if time_search:
if isinstance(time_search, dict):
@@ -128,7 +169,7 @@ async def clean_databases(data) -> str:
elif isinstance(time_search, list):
content_list.extend(time_search)
# 提取文本内容
# Extract text content
text_parts = []
for item in content_list:
if isinstance(item, dict):
@@ -146,10 +187,19 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
'''
"""
Retrieve information using simplified search approach
Processes extended problems from previous nodes and performs retrieval
using either RAG or hybrid search based on storage type. Handles concurrent
processing of multiple questions and deduplicates results.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
@@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
# Create async task to process individual questions
async def process_question_nodes(idx, question):
try:
# Prepare search parameters based on storage type
@@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
}
}
# 并发处理所有问题
# Process all questions concurrently
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
@@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
"""
Advanced retrieve function using LangChain agents and tools
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
to perform sophisticated information retrieval. Supports both RAG and traditional
memory storage approaches with concurrent processing and result deduplication.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
# Get end_user_id from state
import time
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
@@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState:
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题
# Create async task to process individual questions
import asyncio
# 在模块级别定义信号量,限制最大并发数
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
# Define semaphore at module level to limit maximum concurrency
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
async def process_question(idx, question):
async with SEMAPHORE: # 限制并发
async with SEMAPHORE: # Limit concurrency
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
# Use asyncio to run synchronous agent.invoke in thread pool
import asyncio
response = await asyncio.get_event_loop().run_in_executor(
None,
@@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState:
}
}
# 并发处理所有问题
# Process all questions concurrently
import asyncio
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
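Isolated from the surrounding state handling, the bounded-concurrency idiom used by both retrieve functions — a semaphore capping parallel backend calls, fanned out with `asyncio.gather` — is:

import asyncio

async def bounded_gather(items, worker, limit=5):
    sem = asyncio.Semaphore(limit)  # cap concurrent database/LLM operations

    async def run(item):
        async with sem:
            return await worker(item)

    return await asyncio.gather(*(run(item) for item in items))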

View File

@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
"""
Summary node service class
Handles summary generation operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
generating summaries from retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
summary_service = SummaryNodeService()
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings specifically for summary generation.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary with knowledge base settings
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
@@ -54,6 +75,23 @@ async def rag_config(state):
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach for summary generation
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results for summary processing.
Args:
state: Current state containing configuration
question: Question to search for in knowledge base
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
- retrieval_knowledge: List of retrieved knowledge chunks
- clean_content: Formatted content string
- cleaned_query: Processed query string
- raw_results: Raw retrieval results
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
async def summary_history(state: ReadState) -> ReadState:
"""
Retrieve conversation history for summary context
Gets the conversation history for the current user to provide context
for summary generation operations.
Args:
state: ReadState containing end_user_id
Returns:
The conversation history retrieved for the given end user
"""
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数,包含更好的错误处理和数据验证
Enhanced summary_llm function with better error handling and data validation
Generates summaries using LLM with structured output. Includes fallback mechanisms
for handling LLM failures and provides robust error recovery.
Args:
state: ReadState containing current context
history: Conversation history for context
retrieve_info: Retrieved information to summarize
template_name: Jinja2 template name for prompt generation
operation_name: Type of operation (summary, input_summary, retrieve_summary)
response_model: Pydantic model for structured output
search_mode: Search mode flag ("0" for simple, "1" for complex)
Returns:
str: Generated summary text or fallback message
"""
data = state.get("data", '')
# 构建系统提示词
# Build system prompt
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
retrieve_info=retrieve_info
)
try:
# 使用优化的LLM服务进行结构化输出
# Use optimized LLM service for structured output
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
# Validate structured response
if structured is None:
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
# Extract answer based on operation type
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
else:
# 处理RetrieveSummaryResponse
# Handle RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
# Validate answer is not empty
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
# Try unstructured output as fallback
try:
logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple(
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
if response and response.strip():
# 简单清理响应
# Simple response cleaning
cleaned_response = response.strip()
# 移除可能的JSON标记
# Remove possible JSON markers
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
"""
Save summary results to Redis session storage
Stores the generated summary and user query in Redis for session management
and conversation history tracking.
Args:
state: ReadState containing user and query information
aimessages: Generated summary message to save
Returns:
ReadState: Updated state after saving to Redis
"""
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
"""
Format summary results for different output types
Creates structured output formats for both input summary and retrieval summary
operations, including metadata and intermediate results for frontend display.
Args:
state: ReadState containing storage and user information
aimessages: Generated summary message
raw_results: Raw search/retrieval results
Returns:
tuple: (input_summary, retrieve_summary) formatted result dictionaries
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
async def Input_Summary(state: ReadState) -> ReadState:
"""
Generate quick input summary from retrieved information
Performs fast retrieval and generates a quick summary response for user queries.
This function prioritizes speed by only searching summary nodes and provides
immediate feedback to users.
Args:
state: ReadState containing user query, storage configuration, and context
Returns:
ReadState: Dictionary containing summary results with status and metadata
"""
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
@@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
async def Retrieve_Summary(state: ReadState) -> ReadState:
"""
Generate comprehensive summary from retrieved expansion issues
Processes retrieved expansion issues and generates a detailed summary using LLM.
This function handles complex retrieval results and provides comprehensive answers
based on expanded query results.
Args:
state: ReadState containing retrieve data with expansion issues
Returns:
ReadState: Dictionary containing comprehensive summary results
"""
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
@@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary(state: ReadState) -> ReadState:
"""
Generate final comprehensive summary from verified data
Creates the final summary using verified expansion issues and conversation history.
This function processes verified data to generate the most comprehensive and
accurate response to user queries.
Args:
state: ReadState containing verified data and query information
Returns:
ReadState: Dictionary containing final summary results
"""
start = time.time()
query = state.get("data", '')
verify = state.get("verify", '')
@@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - await,然后访问返回值
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
"""
Generate fallback summary when normal summary process fails
Provides a fallback summary generation mechanism when the standard summary
process encounters errors or fails to produce satisfactory results. Uses
a specialized failure template to handle edge cases.
Args:
state: ReadState containing verified data and failure context
Returns:
ReadState: Dictionary containing fallback summary results
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)

View File

@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
"""
Verification node service class
Handles data verification operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
verifying and validating retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
# Create global service instance
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
"""
Process verification results and generate output format
Transforms VerificationResult objects into structured output format suitable
for frontend consumption. Handles conversion of VerificationItem objects to
dictionary format and adds metadata for tracking.
Args:
state: ReadState containing storage and user configuration
messages_deal: VerificationResult containing verification outcomes
Returns:
dict: Formatted verification result with status and metadata
"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# VerificationItem 对象转换为字典列表
# Convert VerificationItem objects to dictionary list
verified_data = []
if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue:
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
# Generate JSON schema to guide LLM output format
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
# Use the optimized LLM service and add timeout protection
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
timeout=150.0 # 150 second timeout
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:

View File

@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
"""
Create and return a LangGraph workflow for memory reading operations
Builds a state graph workflow that handles memory retrieval, problem analysis,
verification, and summarization. The workflow includes nodes for content input,
problem splitting, retrieval, verification, and various summary operations.
Yields:
StateGraph: Compiled LangGraph workflow for memory reading
Raises:
Exception: If workflow creation fails
"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
@@ -48,7 +60,7 @@ async def make_read_graph():
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
# 添加边
# Add edges to define workflow flow
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
@@ -63,7 +75,7 @@ async def make_read_graph():
'''-----'''
# workflow.add_edge("Retrieve", END)
# 编译工作流
# Compile workflow
graph = workflow.compile()
yield graph
@@ -72,108 +84,3 @@ async def make_read_graph():
raise
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
)
import time
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
summary = node_data['InputSummary']['summary_result']
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
summary = node_data['RetrieveSummary']['summary_result']
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
# # 过滤掉空值
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
#
# # 优化搜索结果
# print("=== 开始优化搜索结果 ===")
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
# result=reorder_output_results(optimized_outputs)
# # 保存优化后的结果到文件
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
# import json
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
#
print(f"=== 最终摘要 ===")
print(summary)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,13 +1,13 @@
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = get_agent_logger(__name__)
counter = COUNTState(limit=3)
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
return 'Input_Summary'
return 'Split_The_Problem'  # Default case
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
elif search_switch == '1':
return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status']
status = state.get('verify', {}).get('status', '')
# loop_count = counter.get_total()
if "success" in status:
# counter.reset()
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
# counter.reset()
# counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None

View File

@@ -2,77 +2,104 @@ import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
async def write(
storage_type,
end_user_id,
user_message,
ai_message,
user_rag_memory_id,
actual_end_user_id,
actual_config_id,
long_term_messages=None
):
"""
写入记忆(支持结构化消息)
Write memory with structured message support
Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing.
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: Terminal user identifier
user_message: User message content
ai_message: AI response content
user_rag_memory_id: RAG memory identifier
actual_end_user_id: Actual user identifier for storage
actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing
逻辑说明:
- RAG 模式:组合 user_message ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
Logic explanation:
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j mode: Uses structured message lists
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
3. Each message is converted to independent Chunk, preserving speaker field
"""
db = next(get_db())
try:
if long_term_messages is None:
long_term_messages = []
with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
# Neo4j mode: Use structured message lists
structured_messages = []
# 始终添加用户消息(如果不为空)
# Always add user message (if not empty)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
# Only add assistant message when AI reply is not empty
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages,使用它替代 structured_messages
# If long_term_messages provided, use it to replace structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
# If it's a JSON string, parse it first
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
# If no messages, return directly
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
if type in (AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE):
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
Manages conversation data based on a sliding window approach. When the window
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages
# 获取 config_id(如果 memory_config 是对象,提取 config_id否则直接使用
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
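The `count_store` helpers used here implement a per-user turn counter in Redis; their implementation is not part of this diff. A self-contained stand-in illustrating the same windowing decision — count turns, signal a flush when the window fills, then reset — could look like:

class WindowCounter:
    """Stand-in for count_store: signal a long-term flush every `scope` turns."""

    def __init__(self, scope: int):
        self.scope = scope
        self._counts: dict = {}

    def record_turn(self, user_id: str) -> bool:
        # Returns True when the window is full and long-term storage should run
        count = self._counts.get(user_id, 0) + 1
        if count >= self.scope:
            self._counts[user_id] = 0  # reset after flushing to Neo4j
            return True
        self._counts[user_id] = count
        return False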
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time):
"""
Process memory storage based on time intervals and write to Neo4j
Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
time: Time interval for data retrieval
"""
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
format_messages = long_time_data
messages = []
memory_config = memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
message = json.loads(i['Query'])
messages += message
if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Aggregation judgment function: determine if input sentence and historical messages describe the same event
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval.
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings
Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output
"""
history = None
try:
# 1. 获取历史会话数据(使用新方法)
# 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
{"role": msg.role, "content": msg.content}
for msg in output_value
]
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}
}
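A caller of `aggregate_judgment` can branch on the returned flag; the sketch below is a hypothetical usage pattern grounded in the result dict shape shown in the error path above:

async def handle_turn(end_user_id: str, ori_messages: list, memory_config) -> list:
    # Hypothetical caller: branch on the aggregation verdict
    result = await aggregate_judgment(end_user_id, ori_messages, memory_config)
    if result.get("is_same_event"):
        # aggregate_judgment already submits the processed output to Neo4j
        # (see the write call above); return the merged messages
        return result["output"]
    # Unrelated event: keep the new messages as-is for separate handling
    return result["messages"]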

View File

@@ -2,41 +2,53 @@ import asyncio
import json
from datetime import datetime, timedelta
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
def extract_tool_message_content(response):
"""从agent响应中提取ToolMessage内容和工具名称"""
"""
Extract ToolMessage content and tool names from agent response
Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data.
Args:
response: Agent response dictionary containing messages
Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool
- content: Parsed tool execution result (JSON or raw text)
"""
messages = response.get('messages', [])
for message in messages:
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
# 这是一个ToolMessage
# This is a ToolMessage
tool_content = message.content
tool_name = None
# 尝试获取工具名称
# Try to get tool name
if hasattr(message, 'name'):
tool_name = message.name
elif hasattr(message, 'tool_name'):
tool_name = message.tool_name
try:
# 解析JSON内容
# Parse JSON content
parsed_content = json.loads(tool_content)
return {
'tool_name': tool_name,
'content': parsed_content
}
except json.JSONDecodeError:
# 如果不是JSON格式直接返回内容
# If not JSON format, return content directly
return {
'tool_name': tool_name,
'content': tool_content
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
"""
Input schema for time retrieval tool
Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters.
Attributes:
context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user
"""
context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str):
"""
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results.
Args:
end_user_id: User identifier for scoping search results
Returns:
function: Configured TimeRetrievalWithGroupId tool function
"""
def clean_temporal_result_fields(data):
"""
Clean unnecessary fields from temporal search results and modify structure
Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability.
Args:
data: Data to be cleaned (dict, list, or other types)
Returns:
Cleaned data with unnecessary fields removed
"""
# List of fields to filter out
fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids'
}
if isinstance(data, dict):
cleaned = {}
for key, value in data.items():
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
cleaned_value = clean_temporal_result_fields(value)
# Further change internal statements to time_search
if 'statements' in cleaned_value:
cleaned['results'] = {
'time_search': cleaned_value['statements']
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
return [clean_temporal_result_fields(item) for item in data]
else:
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
Performs time-based search operations with automatic metadata filtering. Supports
flexible date range specification and provides clean, user-friendly output.
Explicit parameters:
- context: Query context content
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results with temporal data
"""
async def _async_search():
# Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# Basic time search
results = await search_by_temporal(
end_user_id=actual_end_user_id,
start_date=actual_start_date,
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
limit=10
)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
@tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
clean_output: bool = True) -> str:
"""
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
Performs combined keyword and temporal search operations with automatic metadata
filtering. Provides more targeted search results by combining content relevance
with time-based filtering.
Explicit parameters:
- context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results combining keyword and temporal data
"""
async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
# Keyword time search
results = await search_by_keyword_temporal(
query_text=context,
end_user_id=end_user_id,
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
limit=15
)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
@@ -162,51 +216,60 @@ def create_time_retrieval_tool(end_user_id: str):
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
return TimeRetrievalWithGroupId
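A minimal usage sketch for the factory above; the end_user_id and date values are placeholders (LangChain @tool objects expose invoke on a dict payload):
time_tool = create_time_retrieval_tool("88a459f5_text09")
result_json = time_tool.invoke({
    "context": "What did I work on recently?",
    "start_date": "2025-10-01",  # optional, defaults to 7 days back
    "end_date": "2025-10-08",    # optional, defaults to today
    "clean_output": True,
})
print(result_json)  # JSON string of cleaned temporal search results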
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"""
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting.
Args:
memory_config: Memory configuration object containing LLM and search settings
**search_params: Search parameters including end_user_id, limit, include, etc.
Returns:
function: Configured HybridSearch tool function with async capabilities
"""
def clean_result_fields(data):
"""
Recursively clean unnecessary fields from results
Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size.
Args:
data: Data to be cleaned (can be dict, list, or other types)
Returns:
Cleaned data with unnecessary fields removed
"""
# List of fields to filter out
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
}
if isinstance(data, dict):
# Clean dictionary
cleaned = {}
for key, value in data.items():
if key not in fields_to_remove:
cleaned[key] = clean_result_fields(value)  # Recursively clean nested data
return cleaned
elif isinstance(data, list):
# Clean each element in list
return [clean_result_fields(item) for item in data]
else:
# Return other types directly
return data
@tool
async def HybridSearch(
context: str,
@@ -216,57 +279,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
clean_output: bool = True  # New: whether to clean output fields
) -> str:
"""
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output.
Args:
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
rerank_alpha: Reranking weight parameter for result scoring
use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: Whether to use LLM-based reranking
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted comprehensive search results
"""
try:
# Import run_hybrid_search function
from app.core.memory.src.search import run_hybrid_search
# Merge parameters, prioritize passed parameters
final_params = {
"query_text": context,
"search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
"output_path": None, # Don't save to file
"memory_config": memory_config,
"rerank_alpha": rerank_alpha,
"use_forgetting_rerank": use_forgetting_rerank,
"use_llm_rerank": use_llm_rerank
}
# Execute hybrid retrieval
raw_results = await run_hybrid_search(**final_params)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_result_fields(raw_results)
else:
cleaned_results = raw_results
# Format return results
formatted_results = {
"search_query": context,
"search_type": search_type,
"results": cleaned_results
}
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
except Exception as e:
error_result = {
"error": f"混合检索失败: {str(e)}",
@@ -275,38 +344,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"timestamp": datetime.now().isoformat()
}
return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch
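A hedged usage sketch for the async factory above; memory_config and the search parameter values are placeholders, not values from this commit:
import asyncio

async def demo_hybrid(memory_config) -> None:
    hybrid = create_hybrid_retrieval_tool_async(memory_config, end_user_id="u_123", limit=5)
    payload = await hybrid.ainvoke({
        "context": "trips to Shanghai",
        "search_type": "hybrid",
        "clean_output": True,
    })
    print(payload)  # JSON string with search_query / search_type / results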
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"""
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments.
Args:
memory_config: Memory configuration object containing search settings
**search_params: Search parameters for configuration
Returns:
function: Configured HybridSearchSync tool function
"""
@tool
def HybridSearchSync(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion.
Args:
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted search results
"""
async def _async_search():
# Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
return await async_tool.ainvoke({
"context": context,
@@ -315,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"end_user_id": end_user_id,
"clean_output": clean_output
})
return asyncio.run(_async_search())
return HybridSearchSync
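One design note on the sync wrapper: asyncio.run() fails if an event loop is already running, so the sync tool is only safe from plain synchronous code. A minimal guard sketch, assuming standard asyncio semantics:
import asyncio

def run_coro_from_sync(coro):
    # asyncio.run raises RuntimeError when a loop is already running,
    # so detect that case explicitly instead of failing mid-search.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)  # no running loop: safe to create one
    raise RuntimeError("call the async tool directly inside an event loop")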

View File

@@ -1,20 +1,28 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list, type: str = 'string'):
"""
Format and parse message lists into different output types
Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization.
Args:
messages: List of message objects from storage containing message data
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
Returns:
list: Formatted message list in the specified format
- 'string': List of formatted text messages with role prefixes
- 'dict': List of dictionaries mapping user messages to AI responses
"""
result = []
user = []
ai = []
for message in messages:
hstory_messages = message['messages']
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role == "user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
if type == "dict":
if role == 'human' or role == "user":
user.append(content)
else:
ai.append(content)
if type == "dict":
for key, values in zip(user, ai):
result.append({key: values})
return result
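Illustrative input/output shapes for the branches above (the sessions payload comes from write_store elsewhere in this module; the sample strings are hypothetical):
import asyncio

async def demo_format(sessions: list) -> None:
    # sessions is whatever write_store.get_all_sessions_by_end_user_id returns
    as_text = await format_parsing(sessions, type="string")
    # e.g. ["用户:hello", "AI:hi there"]
    as_pairs = await format_parsing(sessions, type="dict")
    # e.g. [{"hello": "hi there"}]
    print(as_text, as_pairs)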
async def messages_parse(messages: list | dict):
"""
Parse messages from storage format into user-AI conversation pairs
Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage.
Args:
messages: List or dictionary containing stored message data with Query fields
Returns:
list: List of dictionaries containing user-AI message pairs for database storage
"""
user = []
ai = []
database = []
for message in messages:
Query = message['Query']
Query = json.loads(Query)
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key: values})  # dict pair, not a set literal
return database
async def agent_chat_messages(user_content, ai_content):
"""
Create structured chat message format for agent conversations
Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects.
Args:
user_content: User's message content string
ai_content: AI's response content string
Returns:
list: List of structured message dictionaries with role and content fields
"""
messages = [
{
"role": "user",

View File

@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -42,10 +41,26 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence.
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages)
# Get database session
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
config_id=memory_config,  # Changed to integer
service_name="MemoryAgentService"
)
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
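A minimal dispatch sketch for the strategies above; message content and identifiers are placeholders (AgentMemory_Long_Term constants come from the schema import at the top of this file):
import asyncio

messages = [
    {"role": "user", "content": "Remind me about the Q3 report."},
    {"role": "assistant", "content": "Will do."},
]
asyncio.run(long_term_storage(
    long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK,  # dialogue-window strategy
    langchain_messages=messages,
    memory_config="<config-id>",  # placeholder config identifier
    end_user_id="u_123",
    scope=6,                      # 6-round dialogue window
))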
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages.
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict):
"""
LangGraph writing state TypedDict
"""
messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
data: str
language: str  # Language type ("zh" Chinese, "en" English)
class ReadState(TypedDict):
"""
LangGraph workflow state definition
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
config_id: str
data: str  # New field for passing content
spit_data: dict  # New field for passing question decomposition results
problem_extension: dict
storage_type: str
user_rag_memory_id: str
llm_id: str
embedding_id: str
memory_config: object  # New field for passing the memory configuration object
retrieve: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict
SummaryFails: dict
summary: dict
class COUNTState:
"""
Workflow dialogue retrieval content counter
@@ -99,6 +103,7 @@ class COUNTState:
self.total = 0
print("[COUNTState] 已重置为 0")
def deduplicate_entries(entries):
seen = set()
deduped = []
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
deduped.append(entry)
return deduped
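The loop body is elided by the hunk above; a minimal self-contained sketch of the pattern, with a JSON-based key as a stand-in assumption for the real key derivation:
import json

def deduplicate_entries_sketch(entries: list) -> list:
    # Keep the first occurrence of each entry; the real key derivation lives
    # in the elided lines and may differ from this JSON-based stand-in.
    seen = set()
    deduped = []
    for entry in entries:
        key = json.dumps(entry, sort_keys=True, ensure_ascii=False, default=str)
        if key not in seen:
            seen.add(key)
            deduped.append(entry)
    return deduped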
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
return [convert_extended_question_to_question(item) for item in data]
else:
# Return other types directly
return data

View File

@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate  # Import the Predicate enum for whitelist filtering
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__)
class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM"""
def __init__(
self,
llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"
):
"""Initialize the TripletExtractor with an LLM client
Args:
@@ -65,7 +65,8 @@ class TripletExtractor:
# Create messages for LLM
messages = [
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "system",
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
{"role": "user", "content": prompt_content}
]
@@ -116,7 +117,8 @@ class TripletExtractor:
logger.error(f"Error processing statement: {e}", exc_info=True)
return TripletExtractionResponse(triplets=[], entities=[])
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
str, TripletExtractionResponse]:
"""Extract triplets and entities from statements
Args:

View File

@@ -1,11 +1,11 @@
"""
Self-Reflection Engine Implementation
This module implements the self-reflection functionality of the memory system, including:
1. Time-based reflection - Triggers reflection based on time cycles
2. Fact-based reflection - Detects and resolves memory conflicts
3. Comprehensive reflection - Integrates multiple reflection strategies
4. Reflection result application - Updates memory database
"""
import asyncio
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
)
from pydantic import BaseModel
# Configure logging
_root_logger = logging.getLogger()
if not _root_logger.handlers:
logging.basicConfig(
@@ -49,35 +49,62 @@ else:
_root_logger.setLevel(logging.INFO)
class TranslationResponse(BaseModel):
"""翻译响应模型"""
"""Translation response model for language conversion"""
data: str
class ReflectionRange(str, Enum):
"""反思范围枚举"""
PARTIAL = "partial" # 从检索结果中反思
ALL = "all" # 从整个数据库中反思
"""
Reflection range enumeration
Defines the scope of data to be included in reflection operations.
"""
PARTIAL = "partial" # Reflect from retrieval results
ALL = "all" # Reflect from entire database
class ReflectionBaseline(str, Enum):
"""反思基线枚举"""
TIME = "TIME" # 基于时间的反思
FACT = "FACT" # 基于事实的反思
HYBRID = "HYBRID" # 混合反思
"""
Reflection baseline enumeration
Defines the strategy or approach used for reflection operations.
"""
TIME = "TIME" # Time-based reflection
FACT = "FACT" # Fact-based reflection
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
class ReflectionConfig(BaseModel):
"""反思引擎配置"""
"""
Reflection engine configuration
Defines all configuration parameters for the reflection engine including
operation modes, model settings, and evaluation criteria.
Attributes:
enabled: Whether reflection engine is enabled
iteration_period: Reflection cycle period (e.g., "3" hours)
reflexion_range: Scope of reflection (PARTIAL or ALL)
baseline: Reflection strategy (TIME, FACT, or HYBRID)
model_id: LLM model identifier for reflection operations
end_user_id: User identifier for scoped operations
output_example: Example output format for guidance
memory_verify: Enable memory verification checks
quality_assessment: Enable quality assessment evaluation
violation_handling_strategy: Strategy for handling violations
language_type: Language type for output ("zh" or "en")
"""
enabled: bool = False
iteration_period: str = "3" # 反思周期
iteration_period: str = "3" # Reflection cycle period
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
baseline: ReflectionBaseline = ReflectionBaseline.TIME
model_id: Optional[str] = None  # Model ID
end_user_id: Optional[str] = None
output_example: Optional[str] = None  # Output example
# Evaluation related fields
memory_verify: bool = True # Memory verification
quality_assessment: bool = True # Quality assessment
violation_handling_strategy: str = "warn" # Violation handling strategy
language_type: str = "zh"
class Config:
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
class ReflectionResult(BaseModel):
"""反思结果"""
"""
Reflection operation result
Contains comprehensive information about the outcome of a reflection operation
including success status, metrics, and execution details.
Attributes:
success: Whether the reflection operation succeeded
message: Descriptive message about the operation result
conflicts_found: Number of conflicts detected during reflection
conflicts_resolved: Number of conflicts successfully resolved
memories_updated: Number of memory entries updated in database
execution_time: Total time taken for the reflection operation
details: Additional details about the operation (optional)
"""
success: bool
message: str
conflicts_found: int = 0
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
class ReflectionEngine:
"""
Self-Reflection Engine
Responsible for executing memory system self-reflection operations including
conflict detection, conflict resolution, and memory updates. Supports multiple
reflection strategies and provides comprehensive result tracking.
The engine can operate in different modes:
- Time-based: Reflects on memories within specific time periods
- Fact-based: Detects and resolves factual conflicts in memories
- Hybrid: Combines multiple reflection strategies
Attributes:
config: Reflection engine configuration
neo4j_connector: Neo4j database connector
llm_client: Language model client for analysis
Various function handlers for data processing and prompt rendering
"""
def __init__(
@@ -115,18 +169,21 @@ class ReflectionEngine:
update_query: Optional[str] = None
):
"""
Initialize reflection engine
Sets up the reflection engine with configuration and optional dependencies.
Uses lazy initialization to avoid circular imports and optimize startup time.
Args:
config: Reflection engine configuration object
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
llm_client: LLM client instance (optional, will be created if not provided)
get_data_func: Function for retrieving data (optional, uses default if not provided)
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
conflict_schema: Schema for conflict result validation (optional)
reflexion_schema: Schema for reflection result validation (optional)
update_query: Query string for database updates (optional)
"""
self.config = config
self.neo4j_connector = neo4j_connector
@@ -137,14 +194,20 @@ class ReflectionEngine:
self.conflict_schema = conflict_schema
self.reflexion_schema = reflexion_schema
self.update_query = update_query
self._semaphore = asyncio.Semaphore(5)  # Default concurrency limit of 5
# Lazy import to avoid circular dependencies
self._lazy_init_done = False
def _lazy_init(self):
"""延迟初始化,避免循环导入"""
"""
Lazy initialization to avoid circular imports
Initializes dependencies only when needed, preventing circular import issues
and optimizing startup performance. Sets up default implementations for
any components not provided during construction.
"""
if self._lazy_init_done:
return
@@ -158,7 +221,7 @@ class ReflectionEngine:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(self.config.model_id)
elif isinstance(self.llm_client, str):
# If llm_client is a string (model_id), use it to initialize the client
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
@@ -172,10 +235,10 @@ class ReflectionEngine:
model_config = config_service.get_model_config(model_id)
extra_params = {
"temperature": 0.2, # 降低温度提高响应速度和一致性
"max_tokens": 600, # 限制最大token数
"top_p": 0.8, # 优化采样参数
"stream": False, # 确保非流式输出以获得最快响应
"temperature": 0.2, # Lower temperature for faster response and consistency
"max_tokens": 600, # Limit maximum token count
"top_p": 0.8, # Optimize sampling parameters
"stream": False, # Ensure non-streaming output for fastest response
}
self.llm_client = OpenAIClient(RedBearModelConfig(
@@ -191,7 +254,7 @@ class ReflectionEngine:
if self.get_data_func is None:
self.get_data_func = get_data
# Import get_data_statement function
if not hasattr(self, 'get_data_statement'):
self.get_data_statement = get_data_statement
@@ -223,13 +286,20 @@ class ReflectionEngine:
async def execute_reflection(self, host_id) -> ReflectionResult:
"""
Execute complete reflection workflow
Performs the full reflection process including data retrieval, conflict detection,
conflict resolution, and memory updates. This is the main entry point for
reflection operations.
Args:
host_id: Host identifier for scoping reflection operations
Returns:
ReflectionResult: Comprehensive result of the reflection operation including
success status, conflict metrics, and execution time
"""
# Lazy initialization
self._lazy_init()
if not self.config.enabled:
@@ -243,7 +313,7 @@ class ReflectionEngine:
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
try:
# 1. Get reflection data
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
if not reflexion_data:
return ReflectionResult(
@@ -252,7 +322,7 @@ class ReflectionEngine:
execution_time=asyncio.get_event_loop().time() - start_time
)
# 2. Detect conflicts (fact-based reflection)
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
conflict_list = []
for i in conflict_data:
@@ -261,7 +331,7 @@ class ReflectionEngine:
conflicts_found = 0
# 3. Resolve conflicts
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
if not solved_data:
@@ -276,7 +346,7 @@ class ReflectionEngine:
logging.info(f"解决了 {conflicts_resolved} 个冲突")
# 4. Apply reflection results (update memory database)
memories_updated = await self._apply_reflection_results(solved_data)
execution_time = asyncio.get_event_loop().time() - start_time
@@ -302,7 +372,19 @@ class ReflectionEngine:
)
async def Translate(self, text):
"""
Translate Chinese text to English
Uses the configured LLM to translate Chinese text to English with structured output.
Provides consistent translation format for reflection results.
Args:
text: Chinese text to be translated
Returns:
str: Translated English text
"""
# Translate Chinese to English
translation_messages = [
{
"role": "user",
@@ -316,6 +398,19 @@ class ReflectionEngine:
)
return response.data
async def extract_translation(self, data):
"""
Extract and translate reflection data to English
Processes reflection data structure and translates all Chinese content to English.
Handles nested data structures including memory verifications, quality assessments,
and reflection data while preserving the original structure.
Args:
data: Dictionary containing reflection data with Chinese content
Returns:
dict: Translated data structure with English content
"""
end_datas = {}
end_datas['source_data'] = await self.Translate(data['source_data'])
quality_assessments = []
@@ -350,6 +445,18 @@ class ReflectionEngine:
return end_datas
async def reflection_run(self):
"""
Execute reflection workflow with comprehensive data processing
Performs a complete reflection operation including conflict detection, resolution,
and result formatting. Supports both Chinese and English output based on
configuration settings.
Returns:
dict: Comprehensive reflection results including source data, memory verifications,
quality assessments, and reflection data. Results are translated to English
if language_type is set to 'en'.
"""
self._lazy_init()
start_time = time.time()
memory_verifies_flag = self.config.memory_verify
@@ -367,7 +474,7 @@ class ReflectionEngine:
result_data['source_data'] = "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
# 2. Detect conflicts (fact-based reflection)
conflict_data = await self._detect_conflicts(databasets, source_data)
# Traverse data to extract fields
quality_assessments = []
memory_verifies = []
for item in conflict_data:
@@ -375,9 +482,9 @@ class ReflectionEngine:
memory_verifies.append(item['memory_verify'])
result_data['memory_verifies'] = memory_verifies
result_data['quality_assessments'] = quality_assessments
conflicts_found = 0  # Initialize as integer 0 instead of empty string
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
# Clean conflict_data, memory_verify, and quality_assessment
cleaned_conflict_data = []
for item in conflict_data:
cleaned_item = {
@@ -389,7 +496,7 @@ class ReflectionEngine:
for item in conflict_data:
cleaned_data = []
for row in item.get("data", []):
# Remove created_at / expired_at
cleaned_row = {
k: v
for k, v in row.items()
@@ -402,7 +509,7 @@ class ReflectionEngine:
}
cleaned_conflict_data_.append(cleaned_item)
print(cleaned_conflict_data_)
# 3. Resolve conflicts
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
if not solved_data:
return ReflectionResult(
@@ -413,7 +520,7 @@ class ReflectionEngine:
)
reflexion_data = []
# Traverse data to extract reflexion fields
for item in solved_data:
if 'results' in item:
for result in item['results']:
@@ -431,15 +538,24 @@ class ReflectionEngine:
async def extract_fields_from_json(self):
"""从example.json中提取source_data和databasets字段"""
"""
Extract source_data and databasets fields from example.json
Reads reflection example data from the example.json file and extracts
the source data and database statements for testing and demonstration purposes.
Returns:
tuple: (source_data, databasets) extracted from the example file
Returns empty lists if file reading fails
"""
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
try:
# Read JSON file
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
data = json.loads(f.read())
# Extract fields under memory_verify
memory_verify = data.get("memory_verify", {})
source_data = memory_verify.get("source_data", [])
databasets = memory_verify.get("databasets", [])
@@ -451,15 +567,17 @@ class ReflectionEngine:
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
"""
Get reflection data from database
Retrieves memory data for reflection based on the configured reflection range.
Supports both partial (from retrieval results) and full (entire database) modes.
Args:
host_id: Host UUID identifier for scoping data retrieval
Returns:
tuple: (reflexion_data, statement_data) containing memory data for reflection
Returns empty lists if query fails
"""
print("=== 获取反思数据 ===")
@@ -484,26 +602,29 @@ class ReflectionEngine:
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
Detect conflicts (fact-based reflection)
Uses LLM to analyze memory data and detect conflicts within the memories.
Performs comprehensive conflict detection including memory verification and
quality assessment based on configuration settings.
Args:
data: Memory data to be analyzed for conflicts
statement_databasets: Statement database records for context
Returns:
List[Any]: List of detected conflicts with detailed analysis
"""
if not data:
return []
# Data preprocessing: if data is too small, return no conflicts directly
if len(data) < 2:
logging.info("数据量不足,无需检测冲突")
return []
# Use converted data
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
memory_verify = self.config.memory_verify
logging.info("====== 冲突检测开始 ======")
@@ -512,7 +633,7 @@ class ReflectionEngine:
language_type = self.config.language_type
try:
# Render conflict detection prompt
rendered_prompt = await self.render_evaluate_prompt_func(
data,
self.conflict_schema,
@@ -526,7 +647,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}]
logging.info(f"提示词长度: {len(rendered_prompt)}")
# Call LLM for conflict detection
response = await self.llm_client.response_structured(
messages,
self.conflict_schema
@@ -539,7 +660,7 @@ class ReflectionEngine:
logging.error("LLM 冲突检测输出解析失败")
return []
# Standardize return format
if isinstance(response, BaseModel):
return [response.model_dump()]
elif hasattr(response, 'dict'):
@@ -553,15 +674,17 @@ class ReflectionEngine:
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
Resolve detected conflicts
Uses LLM to perform reflection and resolution on detected conflicts.
Processes conflicts in parallel for efficiency while respecting concurrency limits.
Args:
conflicts: List of conflicts to be resolved
statement_databasets: Statement database records for context
Returns:
List[Any]: List of resolution solutions with reflection results
"""
if not conflicts:
return []
@@ -570,12 +693,12 @@ class ReflectionEngine:
baseline = self.config.baseline
memory_verify = self.config.memory_verify
# Process each conflict in parallel
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
"""解决单个冲突"""
"""Resolve a single conflict"""
async with self._semaphore:
try:
# Render reflection prompt
rendered_prompt = await self.render_reflexion_prompt_func(
[conflict],
self.reflexion_schema,
@@ -587,7 +710,7 @@ class ReflectionEngine:
messages = [{"role": "user", "content": rendered_prompt}]
# Call LLM for reflection
response = await self.llm_client.response_structured(
messages,
self.reflexion_schema
@@ -596,7 +719,7 @@ class ReflectionEngine:
if not response:
return None
# Standardize return format
if isinstance(response, BaseModel):
return response.model_dump()
elif hasattr(response, 'dict'):
@@ -610,11 +733,11 @@ class ReflectionEngine:
logging.warning(f"解决单个冲突失败: {e}")
return None
# Execute all conflict resolution tasks concurrently
tasks = [_resolve_one(conflict) for conflict in conflicts]
results = await asyncio.gather(*tasks, return_exceptions=False)
# Filter out failed results
solved = [r for r in results if r is not None]
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
@@ -626,15 +749,16 @@ class ReflectionEngine:
solved_data: List[Dict[str, Any]]
) -> int:
"""
Apply reflection results (update memory database)
Updates the Neo4j database with resolved conflicts and reflection results.
Processes the solved data and applies changes to the memory storage system.
Args:
solved_data: List of resolved conflict solutions with reflection data
Returns:
int: Number of successfully updated memory entries
"""
changes = extract_and_process_changes(solved_data)
success_count = await neo4j_data(changes)
@@ -642,80 +766,86 @@ class ReflectionEngine:
# Time-based reflection methods
async def time_based_reflection(
self,
host_id: uuid.UUID,
time_period: Optional[str] = None
) -> ReflectionResult:
"""
Time-based reflection
Triggers reflection based on time cycles, checking memories within
specified time periods. Uses the configured iteration period if
no specific time period is provided.
Args:
host_id: Host UUID identifier for scoping reflection
time_period: Time period (e.g., "three hours"), uses config value if not provided
Returns:
ReflectionResult: Comprehensive reflection operation result
"""
period = time_period or self.config.iteration_period
logging.info(f"执行基于时间的反思,周期: {period}")
# Use standard reflection workflow
return await self.execute_reflection(host_id)
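A hedged end-to-end sketch of invoking the engine; the host_id is a placeholder and the engine lazily constructs its own Neo4j and LLM clients as shown above, so this assumes a fully configured environment:
import asyncio
import uuid

config = ReflectionConfig(
    enabled=True,
    baseline=ReflectionBaseline.TIME,
    iteration_period="3",  # reflect over a three-hour window
    language_type="en",
)
engine = ReflectionEngine(config=config)  # clients created on first use
result = asyncio.run(engine.time_based_reflection(host_id=uuid.uuid4()))
print(result.success, result.conflicts_found, result.conflicts_resolved)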
# Fact-based reflection methods
async def fact_based_reflection(
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
Fact-based reflection
Detects and resolves factual conflicts within memories. Analyzes
memory data for inconsistencies and contradictions that need resolution.
Args:
host_id: Host UUID identifier for scoping reflection
Returns:
ReflectionResult: Comprehensive reflection operation result
"""
logging.info("执行基于事实的反思")
# Use standard reflection workflow
return await self.execute_reflection(host_id)
# Comprehensive reflection methods
async def comprehensive_reflection(
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
Comprehensive reflection
Integrates time-based and fact-based reflection strategies based on
the configured baseline. Supports hybrid approaches that combine
multiple reflection methodologies.
Args:
host_id: Host UUID identifier for scoping reflection
Returns:
ReflectionResult: Comprehensive reflection operation result combining
multiple strategies if using hybrid baseline
"""
logging.info("执行综合反思")
# Choose reflection strategy based on configured baseline
if self.config.baseline == ReflectionBaseline.TIME:
return await self.time_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.FACT:
return await self.fact_based_reflection(host_id)
elif self.config.baseline == ReflectionBaseline.HYBRID:
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
time_result = await self.time_based_reflection(host_id)
fact_result = await self.fact_based_reflection(host_id)
# Merge results
return ReflectionResult(
success=time_result.success and fact_result.success,
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",

View File

@@ -2,9 +2,17 @@ import json
def escape_lucene_query(query: str) -> str:
"""Escape Lucene special characters in a free-text query.
This prevents ParseException when using Neo4j full-text procedures.
"""
Escape special characters in Lucene queries
Prevents ParseException when using Neo4j full-text search procedures.
Escapes all Lucene reserved special characters and operators.
Args:
query: Original query string
Returns:
str: Escaped query string safe for Lucene search
"""
if query is None:
return ""
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
return s
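A short usage sketch; the raw query string is illustrative, and the exact per-character escaping is implemented above:
raw_query = 'title:"hello" AND (foo OR bar)'
safe_query = escape_lucene_query(raw_query)
# safe_query can now be passed to a Neo4j full-text procedure such as
# db.index.fulltext.queryNodes without triggering a Lucene ParseException.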
def extract_plain_query(query_input: str) -> str:
"""Extract clean, plain-text query from various input forms.
"""
Extract clean plain-text query from various input forms
Handles the following cases:
- Strips surrounding quotes and whitespace
- If input looks like JSON, prefers the 'original' field
- Falls back to raw string when parsing fails
- Handles dictionary-type input
- Best-effort unescape common escape characters
Args:
query_input: Query input in various forms (string, dict, etc.)
Returns:
str: Extracted plain-text query string
"""
if query_input is None:
return ""

View File

@@ -4,7 +4,13 @@ from datetime import datetime
def validate_date_format(date_str: str) -> bool:
"""
Validate if date string conforms to YYYY-MM-DD format
Args:
date_str: Date string to validate
Returns:
bool: True if format is correct, False otherwise
"""
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
return bool(re.match(pattern, date_str))
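The regex above is a format-only check; a few grounded examples of its behavior:
assert validate_date_format("2025-03-13") is True
assert validate_date_format("2025-3-9") is True     # 1-digit month/day accepted
assert validate_date_format("13-03-2025") is False  # year must come first
assert validate_date_format("2025-99-99") is True   # no range validation, format only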
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
def preprocess_date_string(date_str: str) -> str:
"""预处理日期字符串,处理特殊格式"""
"""
Preprocess date strings to handle special formats
Handles the following special formats:
- Year directly followed by the month with no separator (e.g. "20259/28")
- Pure digit strings with no separators (e.g. "20251028", "251028")
- Mixed separators, normalized to "-"
Args:
date_str: Original date string
Returns:
str: Preprocessed date string in "YYYY-MM-DD" or "YYYY-MM" format
"""
# Handle formats like "20259/28" (year directly followed by the month with no separator)
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
def fallback_parse(date_str: str) -> str:
"""备选解析方案"""
"""
Fallback date parsing scheme
Tries a set of predefined date formats when smart parsing fails.
Supports many common date formats, including:
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
- YYYYMMDD, YYMMDD
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
- YYYY-MM, YYYY/MM, YYYY.MM
Args:
date_str: Date string to parse
Returns:
str: Normalized date string in YYYY-MM-DD format; returns the original string if parsing fails
"""
# Try common date formats
formats_to_try = [

View File

@@ -2,15 +2,15 @@ import os
from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any
# Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
baseline: str = "TIME",
memory_verify: bool = False, quality_assessment: bool = False,
statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
Returns:
Rendered prompt content as string
"""
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("evaluate.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
statement_databasets=None, language_type: str = "zh") -> str:
"""
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
Returns:
Rendered prompt content as a string.
"""
if statement_databasets is None:
statement_databasets = []
template = prompt_env.get_template("reflexion.jinja2")
# Convert Pydantic model to JSON schema if needed
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
json_schema = schema
rendered_prompt = template.render(data=data, json_schema=json_schema,
baseline=baseline, memory_verify=memory_verify,
statement_databasets=statement_databasets, language_type=language_type)
return rendered_prompt
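Why the statement_databasets=None signature change above matters: a default [] is created once at function definition time and shared across calls. A minimal demonstration of the pitfall and the fix pattern used in both renderers:
def broken(acc=[]):   # one shared list for every call
    acc.append(1)
    return acc

def fixed(acc=None):  # fresh list per call, as in the renderers above
    if acc is None:
        acc = []
    acc.append(1)
    return acc

assert broken() == [1] and broken() == [1, 1]  # state leaks across calls
assert fixed() == [1] and fixed() == [1]       # no leakage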

View File

@@ -1,23 +1,19 @@
from __future__ import annotations
import asyncio
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypeVar
import httpx
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.models.models_model import ModelProvider, ModelType
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel, BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableSerializable
from pydantic import BaseModel, Field
T = TypeVar("T")
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# DashScope omni models use the OpenAI-compatible mode
if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI
elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
from langchain_aws import ChatBedrock, ChatBedrockConverse
return ChatBedrock
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
from app.schemas import FileInput
from app.schemas.model_schema import ModelInfo
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -620,11 +621,12 @@ class BaseNode(ABC):
@staticmethod
async def process_message(
api_config: ModelInfo,
content: str | dict | FileObject,
end_user_id: str,
enable_file=False
) -> list | str | None:
provider = api_config.provider
if isinstance(content, dict):
content = FileObject(
type=content.get("type"),
@@ -643,7 +645,7 @@ class BaseNode(ABC):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput(
type=content.type,
url=content.url,
@@ -653,7 +655,8 @@ class BaseNode(ABC):
)
file_obj.set_content(content.get_content())
message = await multimodel_service.process_files(
end_user_id,
[file_obj],
)
content.set_content(file_obj.get_content())
if message:

View File

@@ -5,7 +5,7 @@ from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
"output": VariableType.STRING
}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
result = []
for case in self.typed_config.cases:
expressions = []
for expression in case.expressions:
expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right
if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator,
})
result.append({
"expressions": expressions,
"logical_operator": case.logical_operator,
})
return {
"cases": result
}
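Illustrative shape of the extracted payload for a single constant comparison (the operator and variable values are hypothetical, mirroring the loop above):
extracted_example = {
    "cases": [
        {
            "expressions": [
                {"left": "approved", "right": "approved", "operator": "equals"}
            ],
            "logical_operator": "and",
        }
    ]
}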
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:

View File

@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
"output": VariableType.ARRAY_STRING
}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"query": self._render_template(self.typed_config.query, variable_pool),
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
}
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
"""

View File

@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context
from app.models import ModelType
from app.schemas.model_schema import ModelInfo
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
# Extract all required data before the Session closes
api_config = self.model_balance(config)
model_info = ModelInfo(
model_name=api_config.model_name,
model_type=ModelType(config.type),
api_key=api_config.api_key,
api_base=api_config.api_base,
provider=api_config.provider,
is_omni=api_config.is_omni,
capability=api_config.capability
)
        # 4. Create the LLM instance (using the data extracted above)
        # Note: for streaming output, streaming=True must be set when the model is initialized
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
model_name=model_info.model_name,
provider=model_info.provider,
api_key=model_info.api_key,
base_url=model_info.api_base,
extra_params=extra_params,
is_omni=is_omni
is_omni=model_info.is_omni
),
type=ModelType(model_type)
type=model_info.model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
logger.debug(
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
            # Create the corresponding message object based on the role
if role == "system":
messages.append({
"role": "system",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(
model_info,
content,
user_id,
self.typed_config.vision,
)
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
if messages and messages[-1]["role"] == 'user':
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list):
file_content = []
for file in message["content"]:
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content:
file_content.extend(content)
history_message.append(
{"role": message["role"], "content": file_content}
)
else:
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
message["content"] = await self.process_message(
model_info,
message["content"],
user_id,
self.typed_config.vision
)
history_message.append(message)
messages = messages[:-1] + history_message + messages[-1:]
self.messages = messages
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
        # Call the LLM (streaming); supports a string or a list of messages
last_meta_data = {}
async for chunk in llm.astream(self.messages, stream_usage=True):
async for chunk in llm.astream(self.messages):
# 提取内容
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)

View File

@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
}
return None
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {
"text": self._render_template(self.typed_config.text, variable_pool),
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
"model_id": str(self.typed_config.model_id),
}
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:

api/app/i18n/README.md Normal file
View File

@@ -0,0 +1,61 @@
# Internationalization (i18n) Module
This module provides internationalization support for the MemoryBear API.
## Components
- `service.py` - Translation service and core translation logic
- `middleware.py` - Language detection middleware
- `dependencies.py` - FastAPI dependency injection functions
- `exceptions.py` - Internationalized exception classes
## Usage
### Basic Translation
```python
from app.i18n import t
# Simple translation
message = t("common.success.created")
# Parameterized translation
message = t("common.validation.required", field="Name")
```
### Enum Translation
```python
from app.i18n import t_enum
# Translate enum value
role_display = t_enum("workspace_role", "manager")
```
### In FastAPI Endpoints
```python
from fastapi import Depends
from app.i18n.dependencies import get_translator
@router.post("/workspaces")
async def create_workspace(
data: WorkspaceCreate,
t: Callable = Depends(get_translator)
):
workspace = await workspace_service.create(data)
return {
"success": True,
"message": t("workspace.created_successfully"),
"data": workspace
}
```
## Configuration
See `app/core/config.py` for i18n configuration options:
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)
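For example, a minimal sketch of overriding these options, assuming `settings` is a pydantic `BaseSettings` instance that reads environment variables when `app.core.config` is first imported:
```python
import os

# Hypothetical override: set the variables before the settings module is imported
os.environ["I18N_DEFAULT_LANGUAGE"] = "en"
os.environ["I18N_SUPPORTED_LANGUAGES"] = "zh,en"

from app.core.config import settings

print(settings.I18N_DEFAULT_LANGUAGE)  # "en", if the assumption above holds
```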

api/app/i18n/__init__.py Normal file
View File

@@ -0,0 +1,124 @@
"""
Internationalization (i18n) module for MemoryBear Enterprise.
This module provides complete i18n support for the backend API including:
- Translation loading from multiple directories (community + enterprise)
- Translation service with caching and fallback
- Language detection middleware
- Dependency injection for FastAPI
- Convenience functions for easy usage
Usage:
from app.i18n import t, t_enum
# Simple translation
message = t("common.success.created")
# Parameterized translation
error = t("common.validation.required", field="名称")
# Enum translation
role_display = t_enum("workspace_role", "manager")
"""
from app.i18n.dependencies import (
get_current_language,
get_enum_translator,
get_translator,
)
from app.i18n.exceptions import (
BadRequestError,
ConflictError,
FileNotFoundError,
FileTooLargeError,
ForbiddenError,
I18nException,
InternalServerError,
InvalidCredentialsError,
InvalidFileTypeError,
NotFoundError,
QuotaExceededError,
RateLimitExceededError,
ServiceUnavailableError,
TenantNotFoundError,
TenantSuspendedError,
TokenExpiredError,
TokenInvalidError,
UnauthorizedError,
UserAlreadyExistsError,
UserNotFoundError,
ValidationError,
WorkspaceNotFoundError,
WorkspacePermissionDeniedError,
get_current_locale,
set_current_locale,
)
from app.i18n.loader import TranslationLoader
from app.i18n.logger import (
TranslationLogger,
get_translation_logger,
log_missing_translation,
log_translation_error,
)
from app.i18n.middleware import LanguageMiddleware
from app.i18n.serializers import (
I18nResponseMixin,
WorkspaceSerializer,
WorkspaceMemberSerializer,
WorkspaceInviteSerializer,
)
from app.i18n.service import (
TranslationService,
get_translation_service,
t,
t_enum,
)
__all__ = [
"TranslationLoader",
"LanguageMiddleware",
"TranslationService",
"get_translation_service",
"t",
"t_enum",
"get_current_language",
"get_translator",
"get_enum_translator",
# Context management
"get_current_locale",
"set_current_locale",
# Logging
"TranslationLogger",
"get_translation_logger",
"log_missing_translation",
"log_translation_error",
# Serializers
"I18nResponseMixin",
"WorkspaceSerializer",
"WorkspaceMemberSerializer",
"WorkspaceInviteSerializer",
# Exception classes
"I18nException",
"BadRequestError",
"UnauthorizedError",
"ForbiddenError",
"NotFoundError",
"ConflictError",
"ValidationError",
"InternalServerError",
"ServiceUnavailableError",
"WorkspaceNotFoundError",
"WorkspacePermissionDeniedError",
"UserNotFoundError",
"UserAlreadyExistsError",
"TenantNotFoundError",
"TenantSuspendedError",
"InvalidCredentialsError",
"TokenExpiredError",
"TokenInvalidError",
"FileNotFoundError",
"FileTooLargeError",
"InvalidFileTypeError",
"RateLimitExceededError",
"QuotaExceededError",
]

api/app/i18n/cache.py Normal file
View File

@@ -0,0 +1,291 @@
"""
Advanced caching system for i18n translations.
This module provides:
- LRU cache for hot translations
- Lazy loading mechanism
- Memory optimization
- Cache statistics
"""
import logging
from functools import lru_cache
from typing import Any, Dict, Optional
from collections import OrderedDict
logger = logging.getLogger(__name__)
class TranslationCache:
"""
Advanced translation cache with LRU eviction and lazy loading.
Features:
- LRU cache for frequently accessed translations
- Lazy loading to reduce startup time
- Memory-efficient storage
- Cache hit/miss statistics
"""
def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
"""
Initialize the translation cache.
Args:
max_lru_size: Maximum size of LRU cache for hot translations
enable_lazy_load: Enable lazy loading of locales
"""
self.max_lru_size = max_lru_size
self.enable_lazy_load = enable_lazy_load
# Main cache: {locale: {namespace: {key: value}}}
self._main_cache: Dict[str, Dict[str, Any]] = {}
# LRU cache for hot translations
self._lru_cache: OrderedDict = OrderedDict()
# Loaded locales tracker
self._loaded_locales: set = set()
# Statistics
self._stats = {
"hits": 0,
"misses": 0,
"lru_hits": 0,
"lru_misses": 0,
"lazy_loads": 0
}
logger.info(
f"TranslationCache initialized with LRU size: {max_lru_size}, "
f"lazy loading: {enable_lazy_load}"
)
def set_locale_data(self, locale: str, data: Dict[str, Any]):
"""
Set translation data for a locale.
Args:
locale: Locale code
data: Translation data dictionary
"""
self._main_cache[locale] = data
self._loaded_locales.add(locale)
logger.debug(f"Loaded locale '{locale}' into cache")
def get_translation(
self,
locale: str,
namespace: str,
key_path: list
) -> Optional[str]:
"""
Get translation from cache with LRU optimization.
Args:
locale: Locale code
namespace: Translation namespace
key_path: List of nested keys
Returns:
Translation string or None if not found
"""
# Build cache key for LRU
cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}"
# Check LRU cache first (hot translations)
if cache_key in self._lru_cache:
self._stats["lru_hits"] += 1
self._stats["hits"] += 1
# Move to end (most recently used)
self._lru_cache.move_to_end(cache_key)
return self._lru_cache[cache_key]
self._stats["lru_misses"] += 1
# Check main cache
if locale not in self._main_cache:
self._stats["misses"] += 1
return None
if namespace not in self._main_cache[locale]:
self._stats["misses"] += 1
return None
# Navigate through nested keys
current = self._main_cache[locale][namespace]
for key in key_path:
if isinstance(current, dict) and key in current:
current = current[key]
else:
self._stats["misses"] += 1
return None
# Return only if it's a string value
if not isinstance(current, str):
self._stats["misses"] += 1
return None
self._stats["hits"] += 1
# Add to LRU cache
self._add_to_lru(cache_key, current)
return current
def _add_to_lru(self, key: str, value: str):
"""
Add translation to LRU cache.
Args:
key: Cache key
value: Translation value
"""
# Remove oldest if cache is full
if len(self._lru_cache) >= self.max_lru_size:
self._lru_cache.popitem(last=False)
self._lru_cache[key] = value
def is_locale_loaded(self, locale: str) -> bool:
"""
Check if a locale is loaded.
Args:
locale: Locale code
Returns:
True if locale is loaded
"""
return locale in self._loaded_locales
def get_loaded_locales(self) -> list:
"""
Get list of loaded locales.
Returns:
List of locale codes
"""
return list(self._loaded_locales)
def clear_lru(self):
"""Clear the LRU cache."""
self._lru_cache.clear()
logger.info("LRU cache cleared")
def clear_locale(self, locale: str):
"""
Clear cache for a specific locale.
Args:
locale: Locale code
"""
if locale in self._main_cache:
del self._main_cache[locale]
self._loaded_locales.discard(locale)
# Clear related LRU entries
keys_to_remove = [k for k in self._lru_cache if k.startswith(f"{locale}:")]
for key in keys_to_remove:
del self._lru_cache[key]
logger.info(f"Cleared cache for locale '{locale}'")
def clear_all(self):
"""Clear all caches."""
self._main_cache.clear()
self._lru_cache.clear()
self._loaded_locales.clear()
logger.info("All caches cleared")
def get_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
Dictionary with cache statistics
"""
total_requests = self._stats["hits"] + self._stats["misses"]
hit_rate = (
self._stats["hits"] / total_requests * 100
if total_requests > 0
else 0
)
lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
lru_hit_rate = (
self._stats["lru_hits"] / lru_total * 100
if lru_total > 0
else 0
)
return {
"total_requests": total_requests,
"hits": self._stats["hits"],
"misses": self._stats["misses"],
"hit_rate": round(hit_rate, 2),
"lru_hits": self._stats["lru_hits"],
"lru_misses": self._stats["lru_misses"],
"lru_hit_rate": round(lru_hit_rate, 2),
"lru_size": len(self._lru_cache),
"lru_max_size": self.max_lru_size,
"loaded_locales": len(self._loaded_locales),
"lazy_loads": self._stats["lazy_loads"]
}
def reset_stats(self):
"""Reset cache statistics."""
self._stats = {
"hits": 0,
"misses": 0,
"lru_hits": 0,
"lru_misses": 0,
"lazy_loads": 0
}
logger.info("Cache statistics reset")
def get_memory_usage(self) -> Dict[str, Any]:
"""
Estimate memory usage of the cache.
Returns:
Dictionary with memory usage information
"""
import sys
main_cache_size = sys.getsizeof(self._main_cache)
lru_cache_size = sys.getsizeof(self._lru_cache)
# Rough estimate of nested data
for locale_data in self._main_cache.values():
main_cache_size += sys.getsizeof(locale_data)
for namespace_data in locale_data.values():
main_cache_size += sys.getsizeof(namespace_data)
return {
"main_cache_bytes": main_cache_size,
"lru_cache_bytes": lru_cache_size,
"total_bytes": main_cache_size + lru_cache_size,
"main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
"lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
"total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
}
@lru_cache(maxsize=128)
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
"""
LRU cached function for building translation cache keys.
This reduces string concatenation overhead for frequently accessed keys.
Args:
locale: Locale code
namespace: Translation namespace
key: Translation key
Returns:
Cache key string
"""
return f"{locale}:{namespace}:{key}"

api/app/i18n/dependencies.py Normal file
View File

@@ -0,0 +1,158 @@
"""
FastAPI dependency injection functions for i18n.
This module provides dependency injection functions that can be used
in FastAPI route handlers to access the current language and translator.
"""
import logging
from typing import Callable
from fastapi import Request
from app.i18n.service import get_translation_service
logger = logging.getLogger(__name__)
async def get_current_language(request: Request) -> str:
"""
Get the current language from the request context.
This dependency extracts the language that was determined by the
LanguageMiddleware and stored in request.state.
Args:
request: FastAPI request object
Returns:
Language code (e.g., "zh", "en")
Usage:
@router.get("/example")
async def example(language: str = Depends(get_current_language)):
return {"language": language}
"""
# Get language from request state (set by LanguageMiddleware)
language = getattr(request.state, "language", None)
if language is None:
# Fallback to default language if not set
from app.core.config import settings
language = settings.I18N_DEFAULT_LANGUAGE
logger.warning(
"Language not found in request.state, using default: "
f"{language}"
)
return language
async def get_translator(request: Request) -> Callable:
"""
Get a translator function bound to the current request's language.
This dependency returns a translation function that automatically
uses the current request's language, making it easy to translate
strings in route handlers.
Args:
request: FastAPI request object
Returns:
Translation function with signature: t(key: str, **params) -> str
Usage:
@router.post("/workspaces")
async def create_workspace(
data: WorkspaceCreate,
t: Callable = Depends(get_translator)
):
workspace = await workspace_service.create(data)
return {
"success": True,
"message": t("workspace.created_successfully"),
"data": workspace
}
# With parameters
@router.get("/items")
async def get_items(t: Callable = Depends(get_translator)):
count = 5
return {
"message": t("items.found", count=count)
}
"""
# Get current language
language = await get_current_language(request)
# Get translation service
service = get_translation_service()
# Return a bound translation function
def translate(key: str, **params) -> str:
"""
Translate a key using the current request's language.
Args:
key: Translation key (e.g., "common.success.created")
**params: Parameters for parameterized messages
Returns:
Translated string
"""
return service.translate(key, language, **params)
return translate
async def get_enum_translator(request: Request) -> Callable:
"""
Get an enum translator function bound to the current request's language.
This dependency returns a function for translating enum values
that automatically uses the current request's language.
Args:
request: FastAPI request object
Returns:
Enum translation function with signature:
t_enum(enum_type: str, value: str) -> str
Usage:
@router.get("/workspace/{id}")
async def get_workspace(
id: str,
t_enum: Callable = Depends(get_enum_translator)
):
workspace = await workspace_service.get(id)
return {
"id": workspace.id,
"role": workspace.role,
"role_display": t_enum("workspace_role", workspace.role),
"status": workspace.status,
"status_display": t_enum("workspace_status", workspace.status)
}
"""
# Get current language
language = await get_current_language(request)
# Get translation service
service = get_translation_service()
# Return a bound enum translation function
def translate_enum(enum_type: str, value: str) -> str:
"""
Translate an enum value using the current request's language.
Args:
enum_type: Enum type name (e.g., "workspace_role")
value: Enum value (e.g., "manager")
Returns:
Translated enum display name
"""
return service.translate_enum(enum_type, value, language)
return translate_enum
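A sketch of combining both dependencies in one handler (the router path and the hard-coded role value are hypothetical):
```python
from typing import Callable

from fastapi import APIRouter, Depends

from app.i18n.dependencies import get_enum_translator, get_translator

router = APIRouter()

@router.get("/members/{member_id}")
async def get_member(
    member_id: str,
    t: Callable = Depends(get_translator),
    t_enum: Callable = Depends(get_enum_translator),
):
    role = "manager"  # hypothetical; a real handler would load this from the DB
    return {
        "message": t("common.success.created"),
        "role_display": t_enum("workspace_role", role),
    }
```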

api/app/i18n/exceptions.py Normal file
View File

@@ -0,0 +1,495 @@
"""
Internationalized exception classes for i18n system.
This module provides exception classes that automatically translate
error messages based on the current request's language.
"""
import logging
from contextvars import ContextVar
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from app.i18n.service import get_translation_service
logger = logging.getLogger(__name__)
# Context variable to store current locale
_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None)
def set_current_locale(locale: str) -> None:
"""
Set the current locale in the context variable.
This should be called by the LanguageMiddleware.
Args:
locale: Locale code (e.g., "zh", "en")
"""
_current_locale.set(locale)
def get_current_locale() -> Optional[str]:
"""
Get the current locale from the context variable.
Returns:
Locale code or None if not set
"""
return _current_locale.get()
class I18nException(HTTPException):
"""
Base exception class with automatic i18n support.
This exception automatically translates error messages based on:
1. The current request's language (from request.state.language)
2. The fallback language if request language is not available
3. The error key itself if no translation is found
Features:
- Automatic error message translation
- Parameterized error messages support
- Consistent error response format
- Language-aware error handling
Usage:
# Simple error
raise I18nException(
error_key="errors.workspace.not_found",
status_code=404
)
# Error with parameters
raise I18nException(
error_key="errors.validation.missing_field",
status_code=400,
field="name"
)
# Custom error code
raise I18nException(
error_key="errors.workspace.not_found",
error_code="WORKSPACE_NOT_FOUND",
status_code=404,
workspace_id="123"
)
"""
def __init__(
self,
error_key: str,
status_code: int = 400,
error_code: Optional[str] = None,
locale: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
**params
):
"""
Initialize the i18n exception.
Args:
error_key: Translation key for the error message
(e.g., "errors.workspace.not_found")
status_code: HTTP status code (default: 400)
error_code: Custom error code for API clients
(default: derived from error_key)
locale: Target locale for translation (optional)
If not provided, uses current request's language
headers: Additional HTTP headers
**params: Parameters for parameterized error messages
"""
self.error_key = error_key
self.error_code = error_code or self._generate_error_code(error_key)
self.params = params
# Get locale from request context if not provided
if locale is None:
locale = self._get_current_locale()
# Translate error message
translation_service = get_translation_service()
message = translation_service.translate(
error_key,
locale,
**params
)
# Build error detail
detail = {
"error_code": self.error_code,
"message": message,
}
# Add parameters to detail if provided
if params:
detail["params"] = params
# Initialize HTTPException
super().__init__(
status_code=status_code,
detail=detail,
headers=headers
)
logger.debug(
f"I18nException raised: {self.error_code} "
f"(key: {error_key}, locale: {locale})"
)
def _get_current_locale(self) -> str:
"""
Get the current locale from request context.
Returns:
Locale code (e.g., "zh", "en")
"""
try:
# Try to get locale from context variable
locale = _current_locale.get()
if locale:
return locale
except Exception as e:
logger.debug(f"Could not get locale from context: {e}")
# Fallback to default locale
from app.core.config import settings
return settings.I18N_DEFAULT_LANGUAGE
def _generate_error_code(self, error_key: str) -> str:
"""
Generate error code from error key.
Converts "errors.workspace.not_found" to "WORKSPACE_NOT_FOUND"
Args:
error_key: Translation key
Returns:
Error code in UPPER_SNAKE_CASE
"""
# Remove "errors." prefix if present
if error_key.startswith("errors."):
error_key = error_key[7:]
# Convert to UPPER_SNAKE_CASE
parts = error_key.split(".")
return "_".join(parts).upper()
# Specific exception classes for common errors
class BadRequestError(I18nException):
"""Bad request error (400)."""
def __init__(
self,
error_key: str = "errors.common.bad_request",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=400,
error_code=error_code,
**params
)
class UnauthorizedError(I18nException):
"""Unauthorized error (401)."""
def __init__(
self,
error_key: str = "errors.auth.unauthorized",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=401,
error_code=error_code,
**params
)
class ForbiddenError(I18nException):
"""Forbidden error (403)."""
def __init__(
self,
error_key: str = "errors.auth.forbidden",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=403,
error_code=error_code,
**params
)
class NotFoundError(I18nException):
"""Not found error (404)."""
def __init__(
self,
error_key: str = "errors.common.not_found",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=404,
error_code=error_code,
**params
)
class ConflictError(I18nException):
"""Conflict error (409)."""
def __init__(
self,
error_key: str = "errors.common.conflict",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=409,
error_code=error_code,
**params
)
class ValidationError(I18nException):
"""Validation error (422)."""
def __init__(
self,
error_key: str = "errors.common.validation_failed",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=422,
error_code=error_code,
**params
)
class InternalServerError(I18nException):
"""Internal server error (500)."""
def __init__(
self,
error_key: str = "errors.common.internal_error",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=500,
error_code=error_code,
**params
)
class ServiceUnavailableError(I18nException):
"""Service unavailable error (503)."""
def __init__(
self,
error_key: str = "errors.common.service_unavailable",
error_code: Optional[str] = None,
**params
):
super().__init__(
error_key=error_key,
status_code=503,
error_code=error_code,
**params
)
# Domain-specific exception classes
class WorkspaceNotFoundError(NotFoundError):
"""Workspace not found error."""
def __init__(self, workspace_id: Optional[str] = None, **params):
if workspace_id:
params["workspace_id"] = workspace_id
super().__init__(
error_key="errors.workspace.not_found",
error_code="WORKSPACE_NOT_FOUND",
**params
)
class WorkspacePermissionDeniedError(ForbiddenError):
"""Workspace permission denied error."""
def __init__(self, workspace_id: Optional[str] = None, **params):
if workspace_id:
params["workspace_id"] = workspace_id
super().__init__(
error_key="errors.workspace.permission_denied",
error_code="WORKSPACE_PERMISSION_DENIED",
**params
)
class UserNotFoundError(NotFoundError):
"""User not found error."""
def __init__(self, user_id: Optional[str] = None, **params):
if user_id:
params["user_id"] = user_id
super().__init__(
error_key="errors.user.not_found",
error_code="USER_NOT_FOUND",
**params
)
class UserAlreadyExistsError(ConflictError):
"""User already exists error."""
def __init__(self, identifier: Optional[str] = None, **params):
if identifier:
params["identifier"] = identifier
super().__init__(
error_key="errors.user.already_exists",
error_code="USER_ALREADY_EXISTS",
**params
)
class TenantNotFoundError(NotFoundError):
"""Tenant not found error."""
def __init__(self, tenant_id: Optional[str] = None, **params):
if tenant_id:
params["tenant_id"] = tenant_id
super().__init__(
error_key="errors.tenant.not_found",
error_code="TENANT_NOT_FOUND",
**params
)
class TenantSuspendedError(ForbiddenError):
"""Tenant suspended error."""
def __init__(self, tenant_id: Optional[str] = None, **params):
if tenant_id:
params["tenant_id"] = tenant_id
super().__init__(
error_key="errors.tenant.suspended",
error_code="TENANT_SUSPENDED",
**params
)
class InvalidCredentialsError(UnauthorizedError):
"""Invalid credentials error."""
def __init__(self, **params):
super().__init__(
error_key="errors.auth.invalid_credentials",
error_code="INVALID_CREDENTIALS",
**params
)
class TokenExpiredError(UnauthorizedError):
"""Token expired error."""
def __init__(self, **params):
super().__init__(
error_key="errors.auth.token_expired",
error_code="TOKEN_EXPIRED",
**params
)
class TokenInvalidError(UnauthorizedError):
"""Token invalid error."""
def __init__(self, **params):
super().__init__(
error_key="errors.auth.token_invalid",
error_code="TOKEN_INVALID",
**params
)
class FileNotFoundError(NotFoundError):
"""File not found error."""
def __init__(self, file_id: Optional[str] = None, **params):
if file_id:
params["file_id"] = file_id
super().__init__(
error_key="errors.file.not_found",
error_code="FILE_NOT_FOUND",
**params
)
class FileTooLargeError(BadRequestError):
"""File too large error."""
def __init__(self, max_size: Optional[str] = None, **params):
if max_size:
params["max_size"] = max_size
super().__init__(
error_key="errors.file.too_large",
error_code="FILE_TOO_LARGE",
**params
)
class InvalidFileTypeError(BadRequestError):
"""Invalid file type error."""
def __init__(self, file_type: Optional[str] = None, **params):
if file_type:
params["file_type"] = file_type
super().__init__(
error_key="errors.file.invalid_type",
error_code="INVALID_FILE_TYPE",
**params
)
class RateLimitExceededError(I18nException):
"""Rate limit exceeded error (429)."""
def __init__(self, **params):
super().__init__(
error_key="errors.api.rate_limit_exceeded",
status_code=429,
error_code="RATE_LIMIT_EXCEEDED",
**params
)
class QuotaExceededError(ForbiddenError):
"""Quota exceeded error."""
def __init__(self, resource: Optional[str] = None, **params):
if resource:
params["resource"] = resource
super().__init__(
error_key="errors.api.quota_exceeded",
error_code="QUOTA_EXCEEDED",
**params
)
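A sketch of what raising one of these errors yields, assuming the translation service is initialized and the locale files are loaded:
```python
try:
    raise WorkspaceNotFoundError(workspace_id="123")
except I18nException as exc:
    print(exc.status_code)            # 404
    print(exc.detail["error_code"])   # "WORKSPACE_NOT_FOUND"
    print(exc.detail["params"])       # {"workspace_id": "123"}
    print(exc.detail["message"])      # translated text, or the key if no translation exists
```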

api/app/i18n/loader.py Normal file
View File

@@ -0,0 +1,199 @@
"""
Translation file loader for i18n system.
This module handles loading translation files from multiple directories
(community edition + enterprise edition) and provides hot reload support.
"""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class TranslationLoader:
"""
Translation file loader that supports:
- Loading from multiple directories (community + enterprise)
- Hot reload of translation files
- Automatic locale detection
"""
def __init__(self, locales_dirs: Optional[List[str]] = None):
"""
Initialize the translation loader.
Args:
locales_dirs: List of directories containing translation files.
If None, will auto-detect from settings.
"""
if locales_dirs is None:
locales_dirs = self._detect_locales_dirs()
self.locales_dirs = [Path(d) for d in locales_dirs]
logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}")
def _detect_locales_dirs(self) -> List[str]:
"""
Auto-detect translation directories from settings.
Returns:
List of translation directory paths
"""
from app.core.config import settings
dirs = []
# 1. Core locales directory (community edition, required)
core_dir = Path(settings.I18N_CORE_LOCALES_DIR)
if core_dir.exists():
dirs.append(str(core_dir))
logger.debug(f"Found core locales directory: {core_dir}")
else:
logger.warning(f"Core locales directory not found: {core_dir}")
# 2. Premium locales directory (enterprise edition, optional)
if settings.I18N_PREMIUM_LOCALES_DIR:
premium_dir = Path(settings.I18N_PREMIUM_LOCALES_DIR)
if premium_dir.exists():
dirs.append(str(premium_dir))
logger.debug(f"Found premium locales directory: {premium_dir}")
else:
# Auto-detect premium directory
premium_dir = Path("premium/locales")
if premium_dir.exists():
dirs.append(str(premium_dir))
logger.debug(f"Auto-detected premium locales directory: {premium_dir}")
if not dirs:
logger.error("No translation directories found!")
return dirs
def get_available_locales(self) -> List[str]:
"""
Get list of all available locales across all directories.
Returns:
List of locale codes (e.g., ['zh', 'en'])
"""
locales = set()
for locales_dir in self.locales_dirs:
if not locales_dir.exists():
continue
for locale_dir in locales_dir.iterdir():
if locale_dir.is_dir() and not locale_dir.name.startswith('.'):
locales.add(locale_dir.name)
return sorted(list(locales))
def load_locale(self, locale: str) -> Dict[str, Any]:
"""
Load all translation files for a specific locale from all directories.
Translation files are merged with priority:
- Later directories override earlier directories
- Enterprise translations override community translations
Args:
locale: Locale code (e.g., 'zh', 'en')
Returns:
Dictionary of translations organized by namespace
Format: {namespace: {key: value, ...}, ...}
"""
translations = {}
# Load from each directory in order (later directories override earlier)
for locales_dir in self.locales_dirs:
locale_dir = locales_dir / locale
if not locale_dir.exists():
logger.debug(f"Locale directory not found: {locale_dir}")
continue
# Load all JSON files in this locale directory
for json_file in locale_dir.glob("*.json"):
namespace = json_file.stem
try:
with open(json_file, "r", encoding="utf-8") as f:
new_translations = json.load(f)
# Merge translations (deep merge)
if namespace in translations:
translations[namespace] = self._deep_merge(
translations[namespace],
new_translations
)
logger.debug(
f"Merged translations: {locale}/{namespace} from {json_file}"
)
else:
translations[namespace] = new_translations
logger.debug(
f"Loaded translations: {locale}/{namespace} from {json_file}"
)
except json.JSONDecodeError as e:
logger.error(
f"Failed to parse JSON file {json_file}: {e}"
)
except Exception as e:
logger.error(
f"Failed to load translation file {json_file}: {e}"
)
if not translations:
logger.warning(f"No translations found for locale: {locale}")
return translations
def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
Reload translation files.
Args:
locale: Specific locale to reload. If None, reloads all locales.
Returns:
Dictionary of reloaded translations
Format: {locale: {namespace: {key: value}}}
"""
if locale:
logger.info(f"Reloading translations for locale: {locale}")
return {locale: self.load_locale(locale)}
else:
logger.info("Reloading all translations")
all_translations = {}
for loc in self.get_available_locales():
all_translations[loc] = self.load_locale(loc)
return all_translations
def _deep_merge(self, base: Dict, override: Dict) -> Dict:
"""
Deep merge two dictionaries.
Args:
base: Base dictionary
override: Dictionary with values to override
Returns:
Merged dictionary
"""
result = base.copy()
for key, value in override.items():
if (
key in result
and isinstance(result[key], dict)
and isinstance(value, dict)
):
result[key] = self._deep_merge(result[key], value)
else:
result[key] = value
return result
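A worked example of the merge semantics (calling the private helper directly, purely for illustration):
```python
loader = TranslationLoader(locales_dirs=[])  # no directories needed for this demo

community = {"workspace": {"created": "Workspace created", "deleted": "Workspace deleted"}}
enterprise = {"workspace": {"created": "Workspace provisioned"}}

merged = loader._deep_merge(community, enterprise)
# Nested dicts are merged key by key; overlapping leaves take the override value
assert merged == {
    "workspace": {"created": "Workspace provisioned", "deleted": "Workspace deleted"}
}
```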

api/app/i18n/logger.py Normal file
View File

@@ -0,0 +1,382 @@
"""
Translation logging for i18n system.
This module provides:
- TranslationLogger for recording missing translations
- Missing translation report generation
- Integration with existing logging system
- Structured logging for translation events
"""
import logging
from typing import Dict, List, Optional, Set
from datetime import datetime
from collections import defaultdict
from pathlib import Path
import json
from app.core.logging_config import get_logger
logger = get_logger(__name__)
class TranslationLogger:
"""
Logger for translation events and missing translations.
Features:
- Records missing translations with context
- Generates missing translation reports
- Integrates with existing logging system
- Provides structured logging for analysis
"""
def __init__(self, log_file: Optional[str] = None):
"""
Initialize translation logger.
Args:
log_file: Optional custom log file path for missing translations
"""
self.log_file = log_file or "logs/i18n/missing_translations.log"
self._missing_translations: Dict[str, Set[str]] = defaultdict(set)
self._missing_with_context: List[Dict] = []
self._max_context_entries = 10000 # Keep last 10k entries
# Ensure log directory exists
log_path = Path(self.log_file)
log_path.parent.mkdir(parents=True, exist_ok=True)
# Create dedicated file handler for missing translations
self._file_handler = logging.FileHandler(
self.log_file,
encoding='utf-8'
)
self._file_handler.setLevel(logging.WARNING)
# Create formatter
formatter = logging.Formatter(
fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
self._file_handler.setFormatter(formatter)
# Create dedicated logger for missing translations
self._logger = logging.getLogger("i18n.missing_translations")
self._logger.setLevel(logging.WARNING)
self._logger.addHandler(self._file_handler)
self._logger.propagate = False # Don't propagate to root logger
logger.info(f"TranslationLogger initialized with log file: {self.log_file}")
def log_missing_translation(
self,
key: str,
locale: str,
context: Optional[Dict] = None
):
"""
Log a missing translation.
Args:
key: Translation key that was not found
locale: Locale code
context: Optional context information (e.g., request path, user info)
"""
# Add to missing set
self._missing_translations[locale].add(key)
# Create context entry
entry = {
"timestamp": datetime.now().isoformat(),
"key": key,
"locale": locale,
"context": context or {}
}
# Keep only recent entries to avoid memory bloat
if len(self._missing_with_context) >= self._max_context_entries:
self._missing_with_context.pop(0)
self._missing_with_context.append(entry)
# Log to file
context_str = f" (context: {context})" if context else ""
self._logger.warning(
f"Missing translation: key='{key}', locale='{locale}'{context_str}"
)
def log_translation_error(
self,
error_type: str,
message: str,
key: Optional[str] = None,
locale: Optional[str] = None,
context: Optional[Dict] = None
):
"""
Log a translation error.
Args:
error_type: Type of error (e.g., "format_error", "parameter_missing")
message: Error message
key: Translation key (optional)
locale: Locale code (optional)
context: Optional context information
"""
error_data = {
"error_type": error_type,
"message": message,
"key": key,
"locale": locale,
"context": context or {},
"timestamp": datetime.now().isoformat()
}
        self._logger.error(
            f"Translation error: {error_type} - {message} "
            f"(key: {key}, locale: {locale})",
            extra={"error_data": error_data}
        )
def log_translation_success(
self,
key: str,
locale: str,
duration_ms: Optional[float] = None
):
"""
Log a successful translation (debug level).
Args:
key: Translation key
locale: Locale code
duration_ms: Optional duration in milliseconds
"""
duration_str = f" ({duration_ms:.3f}ms)" if duration_ms else ""
logger.debug(
f"Translation success: key='{key}', locale='{locale}'{duration_str}"
)
def get_missing_translations(
self,
locale: Optional[str] = None
) -> Dict[str, List[str]]:
"""
Get missing translations.
Args:
locale: Specific locale (optional, returns all if None)
Returns:
Dictionary of missing translations by locale
"""
if locale:
return {locale: sorted(list(self._missing_translations.get(locale, set())))}
return {
loc: sorted(list(keys))
for loc, keys in self._missing_translations.items()
}
def get_missing_with_context(
self,
locale: Optional[str] = None,
limit: Optional[int] = None
) -> List[Dict]:
"""
Get missing translations with context.
Args:
locale: Filter by locale (optional)
limit: Maximum number of entries to return (optional)
Returns:
List of missing translation entries with context
"""
entries = self._missing_with_context
# Filter by locale if specified
if locale:
entries = [e for e in entries if e["locale"] == locale]
# Apply limit if specified
if limit:
entries = entries[-limit:]
return entries
def generate_report(
self,
locale: Optional[str] = None,
output_file: Optional[str] = None
) -> Dict:
"""
Generate a missing translation report.
Args:
locale: Specific locale (optional, generates for all if None)
output_file: Optional file path to save report as JSON
Returns:
Report dictionary
"""
missing = self.get_missing_translations(locale)
report = {
"generated_at": datetime.now().isoformat(),
"total_missing": sum(len(keys) for keys in missing.values()),
"missing_by_locale": {
loc: {
"count": len(keys),
"keys": keys
}
for loc, keys in missing.items()
},
"recent_context": self.get_missing_with_context(locale, limit=100)
}
# Save to file if specified
if output_file:
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
logger.info(f"Missing translation report saved to: {output_file}")
return report
def get_statistics(self) -> Dict:
"""
Get statistics about missing translations.
Returns:
Dictionary with statistics
"""
total_missing = sum(len(keys) for keys in self._missing_translations.values())
# Count by namespace
namespace_counts = defaultdict(int)
for locale, keys in self._missing_translations.items():
for key in keys:
namespace = key.split('.')[0] if '.' in key else 'unknown'
namespace_counts[namespace] += 1
return {
"total_missing": total_missing,
"locales_affected": len(self._missing_translations),
"missing_by_locale": {
loc: len(keys)
for loc, keys in self._missing_translations.items()
},
"missing_by_namespace": dict(namespace_counts),
"total_context_entries": len(self._missing_with_context)
}
def clear(self, locale: Optional[str] = None):
"""
Clear missing translation records.
Args:
locale: Specific locale to clear (optional, clears all if None)
"""
if locale:
self._missing_translations.pop(locale, None)
self._missing_with_context = [
e for e in self._missing_with_context
if e["locale"] != locale
]
logger.info(f"Cleared missing translations for locale: {locale}")
else:
self._missing_translations.clear()
self._missing_with_context.clear()
logger.info("Cleared all missing translations")
def export_to_json(self, output_file: str):
"""
Export all missing translations to JSON file.
Args:
output_file: Output file path
"""
data = {
"exported_at": datetime.now().isoformat(),
"missing_translations": self.get_missing_translations(),
"statistics": self.get_statistics(),
"recent_context": self.get_missing_with_context(limit=1000)
}
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
logger.info(f"Missing translations exported to: {output_file}")
def __del__(self):
"""Cleanup file handler on deletion."""
try:
if hasattr(self, '_file_handler'):
self._file_handler.close()
self._logger.removeHandler(self._file_handler)
except Exception:
pass
# Global translation logger instance
_translation_logger: Optional[TranslationLogger] = None
def get_translation_logger() -> TranslationLogger:
"""
Get the global translation logger instance.
Returns:
TranslationLogger singleton
"""
global _translation_logger
if _translation_logger is None:
_translation_logger = TranslationLogger()
return _translation_logger
def log_missing_translation(
key: str,
locale: str,
context: Optional[Dict] = None
):
"""
Log a missing translation (convenience function).
Args:
key: Translation key
locale: Locale code
context: Optional context information
"""
translation_logger = get_translation_logger()
translation_logger.log_missing_translation(key, locale, context)
def log_translation_error(
error_type: str,
message: str,
key: Optional[str] = None,
locale: Optional[str] = None,
context: Optional[Dict] = None
):
"""
Log a translation error (convenience function).
Args:
error_type: Type of error
message: Error message
key: Translation key (optional)
locale: Locale code (optional)
context: Optional context information
"""
translation_logger = get_translation_logger()
translation_logger.log_translation_error(
error_type, message, key, locale, context
)
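A usage sketch of the convenience functions above (note that constructing the logger creates `logs/i18n/` as a side effect):
```python
log_missing_translation(
    "workspace.created_successfully",
    "en",
    context={"path": "/api/workspaces"},  # hypothetical request context
)

report = get_translation_logger().generate_report()
print(report["total_missing"])            # 1 on a fresh logger
print(report["missing_by_locale"]["en"])  # {"count": 1, "keys": [...]}
```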

api/app/i18n/metrics.py Normal file
View File

@@ -0,0 +1,337 @@
"""
Performance monitoring and metrics for i18n system.
This module provides:
- Translation request counters
- Translation timing metrics
- Missing translation tracking
- Performance monitoring decorators
- Prometheus-compatible metrics
"""
import logging
import time
from functools import wraps
from typing import Any, Callable, Dict, Optional
from collections import defaultdict
from datetime import datetime
logger = logging.getLogger(__name__)
class TranslationMetrics:
"""
Metrics collector for translation operations.
Tracks:
- Translation request counts
- Translation timing (latency)
- Missing translations
- Cache performance
- Locale usage
"""
def __init__(self):
"""Initialize metrics collector."""
# Request counters by locale
self._request_counts: Dict[str, int] = defaultdict(int)
# Missing translation tracker
self._missing_translations: Dict[str, set] = defaultdict(set)
# Timing metrics (in milliseconds)
self._timing_data: list = []
self._max_timing_samples = 10000 # Keep last 10k samples
# Locale usage
self._locale_usage: Dict[str, int] = defaultdict(int)
# Namespace usage
self._namespace_usage: Dict[str, int] = defaultdict(int)
# Error counts
self._error_counts: Dict[str, int] = defaultdict(int)
# Start time
self._start_time = datetime.now()
logger.info("TranslationMetrics initialized")
    def record_request(self, locale: str, namespace: Optional[str] = None):
"""
Record a translation request.
Args:
locale: Locale code
namespace: Translation namespace (optional)
"""
self._request_counts[locale] += 1
self._locale_usage[locale] += 1
if namespace:
self._namespace_usage[namespace] += 1
def record_missing(self, key: str, locale: str):
"""
Record a missing translation.
Args:
key: Translation key
locale: Locale code
"""
self._missing_translations[locale].add(key)
logger.debug(f"Missing translation recorded: {key} (locale: {locale})")
def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"):
"""
Record translation operation timing.
Args:
duration_ms: Duration in milliseconds
locale: Locale code
operation: Operation type
"""
# Keep only recent samples to avoid memory bloat
if len(self._timing_data) >= self._max_timing_samples:
self._timing_data.pop(0)
self._timing_data.append({
"duration_ms": duration_ms,
"locale": locale,
"operation": operation,
"timestamp": time.time()
})
def record_error(self, error_type: str):
"""
Record an error.
Args:
error_type: Type of error
"""
self._error_counts[error_type] += 1
def get_summary(self) -> Dict[str, Any]:
"""
Get metrics summary.
Returns:
Dictionary with metrics summary
"""
total_requests = sum(self._request_counts.values())
total_missing = sum(len(keys) for keys in self._missing_translations.values())
# Calculate timing statistics
timing_stats = self._calculate_timing_stats()
# Calculate uptime
uptime_seconds = (datetime.now() - self._start_time).total_seconds()
return {
"uptime_seconds": round(uptime_seconds, 2),
"total_requests": total_requests,
"requests_per_locale": dict(self._request_counts),
"total_missing_translations": total_missing,
"missing_by_locale": {
locale: len(keys)
for locale, keys in self._missing_translations.items()
},
"timing": timing_stats,
"locale_usage": dict(self._locale_usage),
"namespace_usage": dict(self._namespace_usage),
"error_counts": dict(self._error_counts)
}
def _calculate_timing_stats(self) -> Dict[str, Any]:
"""
Calculate timing statistics.
Returns:
Dictionary with timing statistics
"""
if not self._timing_data:
return {
"count": 0,
"avg_ms": 0,
"min_ms": 0,
"max_ms": 0,
"p50_ms": 0,
"p95_ms": 0,
"p99_ms": 0
}
durations = [d["duration_ms"] for d in self._timing_data]
durations.sort()
count = len(durations)
avg = sum(durations) / count
# Calculate percentiles
p50_idx = int(count * 0.50)
p95_idx = int(count * 0.95)
p99_idx = int(count * 0.99)
return {
"count": count,
"avg_ms": round(avg, 3),
"min_ms": round(durations[0], 3),
"max_ms": round(durations[-1], 3),
"p50_ms": round(durations[p50_idx], 3),
"p95_ms": round(durations[p95_idx], 3),
"p99_ms": round(durations[p99_idx], 3)
}
def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]:
"""
Get missing translations.
Args:
locale: Specific locale (optional, returns all if None)
Returns:
Dictionary of missing translations by locale
"""
if locale:
return {locale: list(self._missing_translations.get(locale, set()))}
return {
locale: list(keys)
for locale, keys in self._missing_translations.items()
}
def reset(self):
"""Reset all metrics."""
self._request_counts.clear()
self._missing_translations.clear()
self._timing_data.clear()
self._locale_usage.clear()
self._namespace_usage.clear()
self._error_counts.clear()
self._start_time = datetime.now()
logger.info("Metrics reset")
def export_prometheus(self) -> str:
"""
Export metrics in Prometheus format.
Returns:
Prometheus-formatted metrics string
"""
lines = []
# Translation requests counter
lines.append("# HELP i18n_translation_requests_total Total number of translation requests")
lines.append("# TYPE i18n_translation_requests_total counter")
for locale, count in self._request_counts.items():
lines.append(f'i18n_translation_requests_total{{locale="{locale}"}} {count}')
# Missing translations counter
lines.append("# HELP i18n_missing_translations_total Total number of missing translations")
lines.append("# TYPE i18n_missing_translations_total counter")
for locale, keys in self._missing_translations.items():
lines.append(f'i18n_missing_translations_total{{locale="{locale}"}} {len(keys)}')
# Timing metrics
timing_stats = self._calculate_timing_stats()
lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds")
lines.append("# TYPE i18n_translation_duration_ms summary")
lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}')
lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}')
lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}')
lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}')
lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}')
# Error counter
lines.append("# HELP i18n_errors_total Total number of i18n errors")
lines.append("# TYPE i18n_errors_total counter")
for error_type, count in self._error_counts.items():
lines.append(f'i18n_errors_total{{type="{error_type}"}} {count}')
return "\n".join(lines)
# Global metrics instance
_metrics: Optional[TranslationMetrics] = None
def get_metrics() -> TranslationMetrics:
"""
Get the global metrics instance.
Returns:
TranslationMetrics singleton
"""
global _metrics
if _metrics is None:
_metrics = TranslationMetrics()
return _metrics
def monitor_performance(operation: str = "translate"):
"""
Decorator to monitor translation operation performance.
Args:
operation: Operation name for metrics
Returns:
Decorated function
Example:
@monitor_performance("translate")
def translate(key: str, locale: str) -> str:
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
start_time = time.perf_counter()
try:
result = func(*args, **kwargs)
# Record timing
duration_ms = (time.perf_counter() - start_time) * 1000
# Try to extract locale from args/kwargs
                locale = kwargs.get("locale")
                if not locale and len(args) > 1:
                    locale = args[1] if isinstance(args[1], str) else None
                locale = locale or "unknown"
metrics = get_metrics()
metrics.record_timing(duration_ms, locale, operation)
return result
except Exception as e:
# Record error
metrics = get_metrics()
metrics.record_error(type(e).__name__)
raise
return wrapper
return decorator
def track_missing_translation(key: str, locale: str):
"""
Track a missing translation.
Args:
key: Translation key
locale: Locale code
"""
metrics = get_metrics()
metrics.record_missing(key, locale)
def track_translation_request(locale: str, namespace: Optional[str] = None):
"""
Track a translation request.
Args:
locale: Locale code
namespace: Translation namespace (optional)
"""
metrics = get_metrics()
metrics.record_request(locale, namespace)
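A sketch of instrumenting a helper with the decorator above; `fake_translate` is a stand-in, not a real service function:
```python
@monitor_performance("translate")
def fake_translate(key: str, locale: str = "en") -> str:
    return key.upper()  # placeholder for a real lookup

fake_translate("common.success.created", locale="en")

summary = get_metrics().get_summary()
print(summary["timing"]["count"], summary["timing"]["avg_ms"])
```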

api/app/i18n/middleware.py Normal file
View File

@@ -0,0 +1,202 @@
"""
Language detection middleware for i18n system.
This middleware determines the language to use for each request based on:
1. Query parameter (?lang=en)
2. Accept-Language HTTP header
3. User language preference (from database)
4. Tenant default language
5. System default language
The detected language is injected into request.state.language and
added to the response Content-Language header.
"""
import logging
import re
from typing import Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class LanguageMiddleware(BaseHTTPMiddleware):
"""
Language detection middleware.
Determines the language for each request based on multiple sources
with a clear priority order, validates the language is supported,
and injects it into the request context.
"""
async def dispatch(self, request: Request, call_next):
"""
Process the request and determine the language.
Args:
request: The incoming request
call_next: The next middleware/handler in the chain
Returns:
Response with Content-Language header added
"""
# Determine the language for this request
language = await self._determine_language(request)
# Validate language is supported
from app.core.config import settings
if language not in settings.I18N_SUPPORTED_LANGUAGES:
logger.warning(
f"Unsupported language '{language}' requested, "
f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}"
)
language = settings.I18N_DEFAULT_LANGUAGE
# Inject language into request state
request.state.language = language
# Also set in context variable for exception handling
from app.i18n.exceptions import set_current_locale
set_current_locale(language)
logger.debug(f"Request language set to: {language}")
# Process the request
response = await call_next(request)
# Add Content-Language header to response
response.headers["Content-Language"] = language
return response
async def _determine_language(self, request: Request) -> str:
"""
Determine the language to use based on priority order.
Priority:
1. Query parameter (?lang=en)
2. Accept-Language HTTP header
3. User language preference (from database)
4. Tenant default language
5. System default language
Args:
request: The incoming request
Returns:
Language code (e.g., "zh", "en")
"""
from app.core.config import settings
# 1. Check query parameter (?lang=en)
if "lang" in request.query_params:
lang = request.query_params["lang"].strip().lower()
if lang:
logger.debug(f"Language from query parameter: {lang}")
return lang
# 2. Check Accept-Language HTTP header
if "Accept-Language" in request.headers:
lang = self._parse_accept_language(
request.headers["Accept-Language"]
)
if lang:
logger.debug(f"Language from Accept-Language header: {lang}")
return lang
# 3. Check user language preference (requires authentication)
# Note: This assumes user is already loaded into request.state by auth middleware
if hasattr(request.state, "user") and request.state.user:
user = request.state.user
if hasattr(user, "preferred_language") and user.preferred_language:
logger.debug(
f"Language from user preference: {user.preferred_language}"
)
return user.preferred_language
# 4. Check tenant default language
# Note: This assumes tenant is already loaded into request.state
if hasattr(request.state, "tenant") and request.state.tenant:
tenant = request.state.tenant
if hasattr(tenant, "default_language") and tenant.default_language:
logger.debug(
f"Language from tenant default: {tenant.default_language}"
)
return tenant.default_language
# 5. Fall back to system default language
logger.debug(
f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}"
)
return settings.I18N_DEFAULT_LANGUAGE
def _parse_accept_language(self, header: str) -> Optional[str]:
"""
Parse the Accept-Language HTTP header.
The Accept-Language header format:
Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7
This method:
1. Parses all language codes and their quality values
2. Extracts the base language code (zh-CN -> zh)
3. Sorts by quality value (higher first)
4. Returns the first supported language
Args:
header: Accept-Language header value
Returns:
Language code if found and supported, None otherwise
Examples:
_parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8")
# => "zh" (if zh is supported)
_parse_accept_language("en-US,en;q=0.9")
# => "en" (if en is supported)
"""
from app.core.config import settings
if not header:
return None
# Parse language preferences with quality values
languages = []
for item in header.split(","):
item = item.strip()
if not item:
continue
# Split language code and quality value
parts = item.split(";")
lang_code = parts[0].strip()
# Extract base language code (zh-CN -> zh, en-US -> en)
base_lang = lang_code.split("-")[0].lower()
# Extract quality value (default: 1.0)
quality = 1.0
if len(parts) > 1:
# Look for q=0.9 pattern
q_match = re.search(r"q=([\d.]+)", parts[1])
if q_match:
try:
quality = float(q_match.group(1))
except ValueError:
quality = 1.0
languages.append((base_lang, quality))
# Sort by quality value (descending)
languages.sort(key=lambda x: x[1], reverse=True)
# Return the first supported language
for lang_code, _ in languages:
if lang_code in settings.I18N_SUPPORTED_LANGUAGES:
return lang_code
return None
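Wiring the middleware into the application uses FastAPI's standard `add_middleware` API (a minimal sketch):
```python
from fastapi import FastAPI

app = FastAPI()
app.add_middleware(LanguageMiddleware)

# A request such as GET /items?lang=en now runs with
# request.state.language == "en" and the response carries
# the header "Content-Language: en".
```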

api/app/i18n/serializers.py Normal file
View File

@@ -0,0 +1,221 @@
"""
国际化响应序列化器
提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。
"""
from typing import Any, Dict, List, Union
from pydantic import BaseModel
class I18nResponseMixin:
"""国际化响应混入类
为响应数据添加国际化字段,特别是为枚举值添加 _display 后缀的翻译字段。
使用方法:
1. 继承此类
2. 实现 _get_enum_fields() 方法定义需要翻译的枚举字段
3. 调用 serialize_with_i18n() 方法序列化数据
示例:
class WorkspaceSerializer(I18nResponseMixin):
def _get_enum_fields(self) -> Dict[str, str]:
return {
"role": "workspace_role",
"status": "workspace_status"
}
def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict:
data = {
"id": str(workspace.id),
"name": workspace.name,
"role": workspace.role,
"status": workspace.status
}
return self.serialize_with_i18n(data, locale)
"""
def serialize_with_i18n(
self,
data: Any,
locale: str = "zh"
) -> Union[Dict, List[Dict], Any]:
"""序列化数据并添加国际化字段
Args:
data: 要序列化的数据(字典、列表或 Pydantic 模型)
locale: 语言代码
Returns:
序列化后的数据,包含国际化字段
"""
# 如果是 Pydantic 模型,转换为字典
if isinstance(data, BaseModel):
data = data.model_dump()
# 处理不同类型的数据
if isinstance(data, dict):
return self._serialize_dict(data, locale)
elif isinstance(data, list):
return [self._serialize_dict(item, locale) if isinstance(item, dict) else item for item in data]
else:
return data
def _serialize_dict(self, data: Dict, locale: str) -> Dict:
"""序列化字典并添加 _display 字段
Args:
data: 字典数据
locale: 语言代码
Returns:
添加了 _display 字段的字典
"""
from app.i18n.service import get_translation_service
translation_service = get_translation_service()
result = data.copy()
# 获取需要翻译的枚举字段
enum_fields = self._get_enum_fields()
# 为每个枚举字段添加 _display 字段
for field, enum_type in enum_fields.items():
if field in result and result[field] is not None:
value = result[field]
# 翻译枚举值
display_value = translation_service.translate_enum(
enum_type=enum_type,
value=str(value),
locale=locale
)
# 添加 _display 字段
result[f"{field}_display"] = display_value
return result
def _get_enum_fields(self) -> Dict[str, str]:
"""获取需要翻译的枚举字段
子类必须实现此方法,返回字段名到枚举类型的映射。
Returns:
字段名到枚举类型的映射
例如: {"role": "workspace_role", "status": "workspace_status"}
"""
return {}
class WorkspaceSerializer(I18nResponseMixin):
"""Workspace serializer
Adds i18n fields to workspace responses.
"""
def _get_enum_fields(self) -> Dict[str, str]:
"""Declare the workspace enum fields"""
return {
"role": "workspace_role",
"status": "workspace_status"
}
def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
"""Serialize workspace data
Args:
workspace_data: Workspace data (dict or Pydantic model)
locale: Language code
Returns:
Serialized workspace data with i18n fields
"""
return self.serialize_with_i18n(workspace_data, locale)
def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
"""Serialize a list of workspaces
Args:
workspaces: List of workspaces
locale: Language code
Returns:
Serialized list of workspaces
"""
return [self.serialize(ws, locale) for ws in workspaces]
class WorkspaceMemberSerializer(I18nResponseMixin):
"""Workspace member serializer
Adds i18n fields to workspace member responses.
"""
def _get_enum_fields(self) -> Dict[str, str]:
"""Declare the workspace member enum fields"""
return {
"role": "workspace_role"
}
def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
"""Serialize workspace member data
Args:
member_data: Member data (dict or Pydantic model)
locale: Language code
Returns:
Serialized member data with i18n fields
"""
return self.serialize_with_i18n(member_data, locale)
def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
"""Serialize a list of workspace members
Args:
members: List of members
locale: Language code
Returns:
Serialized list of members
"""
return [self.serialize(member, locale) for member in members]
class WorkspaceInviteSerializer(I18nResponseMixin):
"""Workspace invite serializer
Adds i18n fields to workspace invite responses.
"""
def _get_enum_fields(self) -> Dict[str, str]:
"""Declare the workspace invite enum fields"""
return {
"status": "invite_status",
"role": "workspace_role"
}
def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
"""Serialize workspace invite data
Args:
invite_data: Invite data (dict or Pydantic model)
locale: Language code
Returns:
Serialized invite data with i18n fields
"""
return self.serialize_with_i18n(invite_data, locale)
def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
"""Serialize a list of workspace invites
Args:
invites: List of invites
locale: Language code
Returns:
Serialized list of invites
"""
return [self.serialize(invite, locale) for invite in invites]
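As a usage sketch (assuming the `workspace_role` and `workspace_status` entries from the `enums.json` files added below):

```python
from app.i18n.serializers import WorkspaceSerializer

serializer = WorkspaceSerializer()
data = {"id": "42", "name": "Demo", "role": "manager", "status": "active"}

# The original fields are kept and translated *_display fields are appended:
serializer.serialize(data, locale="zh")
# -> {"id": "42", "name": "Demo", "role": "manager", "role_display": "管理员",
#     "status": "active", "status_display": "活跃"}
serializer.serialize(data, locale="en")["role_display"]  # -> "Manager"
```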

api/app/i18n/service.py Normal file
View File

@@ -0,0 +1,370 @@
"""
Translation service for i18n system.
This module provides the core translation functionality including:
- Translation lookup with fallback mechanism
- Parameterized message support
- Enum value translation
- Memory caching for performance
- Performance monitoring and metrics
"""
import logging
from functools import lru_cache
from typing import Any, Dict, Optional
from app.i18n.loader import TranslationLoader
from app.i18n.cache import TranslationCache
from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request
from app.i18n.logger import get_translation_logger
logger = logging.getLogger(__name__)
class TranslationService:
"""
Translation service that provides:
- Fast translation lookup with memory cache
- Parameterized message support ({param} syntax)
- Fallback mechanism (current locale → default locale → key)
- Enum value translation
- Deep merge of multi-directory translations
"""
def __init__(self, locales_dirs: Optional[list] = None):
"""
Initialize the translation service.
Args:
locales_dirs: List of directories containing translation files.
If None, will auto-detect from settings.
"""
from app.core.config import settings
self.loader = TranslationLoader(locales_dirs)
self.default_locale = settings.I18N_DEFAULT_LANGUAGE
self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE
self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS
self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE
# Initialize advanced cache with LRU
lru_cache_size = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000)
self.cache = TranslationCache(
max_lru_size=lru_cache_size,
enable_lazy_load=False # Load all at startup for now
)
# Load all translations into cache
self._load_all_locales()
# Initialize metrics
self.metrics = get_metrics()
# Initialize translation logger
self.translation_logger = get_translation_logger()
logger.info(
f"TranslationService initialized with default locale: {self.default_locale}, "
f"LRU cache size: {lru_cache_size}"
)
def _load_all_locales(self):
"""Load all available locales into memory cache."""
available_locales = self.loader.get_available_locales()
logger.info(f"Loading translations for locales: {available_locales}")
for locale in available_locales:
locale_data = self.loader.load_locale(locale)
self.cache.set_locale_data(locale, locale_data)
logger.info(f"Loaded {len(available_locales)} locales into cache")
@monitor_performance("translate")
def translate(
self,
key: str,
locale: Optional[str] = None,
**params
) -> str:
"""
Translate a key to the target locale.
Supports:
- Dot-separated keys (e.g., "common.success.created")
- Parameterized messages (e.g., "Hello {name}")
- Fallback mechanism
Args:
key: Translation key (format: "namespace.key.subkey")
locale: Target locale (defaults to default locale)
**params: Parameters for parameterized messages
Returns:
Translated string, or the key itself if translation not found
Examples:
translate("common.success.created", "zh")
# => "创建成功"
translate("common.validation.required", "zh", field="名称")
# => "名称不能为空"
"""
if locale is None:
locale = self.default_locale
# Parse key (namespace.key.subkey)
parts = key.split(".", 1)
if len(parts) < 2:
if self.log_missing:
logger.warning(f"Invalid translation key format: {key}")
return key
namespace = parts[0]
key_path = parts[1].split(".")
# Track request
track_translation_request(locale, namespace)
# Get translation from cache
translation = self.cache.get_translation(locale, namespace, key_path)
# Fallback to default locale if not found
if translation is None and locale != self.fallback_locale:
translation = self.cache.get_translation(
self.fallback_locale, namespace, key_path
)
# If still not found, return the key itself
if translation is None:
if self.log_missing:
logger.warning(
f"Missing translation: {key} (locale: {locale})"
)
track_missing_translation(key, locale)
# Log to translation logger with context
self.translation_logger.log_missing_translation(
key=key,
locale=locale,
context={"namespace": namespace}
)
return key
# Apply parameters if provided
if params:
try:
translation = translation.format(**params)
except KeyError as e:
error_msg = f"Missing parameter in translation '{key}': {e}"
logger.error(error_msg)
self.translation_logger.log_translation_error(
error_type="parameter_missing",
message=error_msg,
key=key,
locale=locale,
context={"params": list(params.keys())}
)
except Exception as e:
error_msg = f"Error formatting translation '{key}': {e}"
logger.error(error_msg)
self.translation_logger.log_translation_error(
error_type="format_error",
message=error_msg,
key=key,
locale=locale
)
return translation
def _get_translation(
self,
locale: str,
namespace: str,
key_path: list
) -> Optional[str]:
"""
Get translation from cache (deprecated, use cache.get_translation).
Args:
locale: Locale code
namespace: Translation namespace
key_path: List of nested keys
Returns:
Translation string or None if not found
"""
return self.cache.get_translation(locale, namespace, key_path)
@monitor_performance("translate_enum")
def translate_enum(
self,
enum_type: str,
value: str,
locale: Optional[str] = None
) -> str:
"""
Translate an enum value.
Args:
enum_type: Enum type name (e.g., "workspace_role")
value: Enum value (e.g., "manager")
locale: Target locale
Returns:
Translated enum display name
Examples:
translate_enum("workspace_role", "manager", "zh")
# => "管理员"
translate_enum("invite_status", "pending", "en")
# => "Pending"
"""
key = f"enums.{enum_type}.{value}"
return self.translate(key, locale)
def has_translation(self, key: str, locale: str) -> bool:
"""
Check if a translation exists for the given key and locale.
Args:
key: Translation key
locale: Locale code
Returns:
True if translation exists, False otherwise
"""
parts = key.split(".", 1)
if len(parts) < 2:
return False
namespace = parts[0]
key_path = parts[1].split(".")
translation = self.cache.get_translation(locale, namespace, key_path)
return translation is not None
def reload(self, locale: Optional[str] = None):
"""
Reload translation files.
Args:
locale: Specific locale to reload. If None, reloads all locales.
"""
logger.info(f"Reloading translations for locale: {locale or 'all'}")
if locale:
locale_data = self.loader.load_locale(locale)
self.cache.set_locale_data(locale, locale_data)
# Clear LRU cache for this locale
self.cache.clear_locale(locale)
else:
self._load_all_locales()
# Clear all LRU cache
self.cache.clear_lru()
logger.info("Translation reload completed")
def get_available_locales(self) -> list:
"""
Get list of all available locales.
Returns:
List of locale codes
"""
return self.cache.get_loaded_locales()
def get_cache_stats(self) -> Dict[str, Any]:
"""
Get cache statistics.
Returns:
Dictionary with cache statistics
"""
return self.cache.get_stats()
def get_metrics_summary(self) -> Dict[str, Any]:
"""
Get metrics summary.
Returns:
Dictionary with metrics summary
"""
return self.metrics.get_summary()
def get_memory_usage(self) -> Dict[str, Any]:
"""
Get memory usage information.
Returns:
Dictionary with memory usage information
"""
return self.cache.get_memory_usage()
def get_loaded_dirs(self) -> list:
"""
Get list of loaded translation directories.
Returns:
List of directory paths
"""
return self.loader.locales_dirs
# Global singleton instance
_translation_service: Optional[TranslationService] = None
def get_translation_service() -> TranslationService:
"""
Get the global translation service instance.
Returns:
TranslationService singleton
"""
global _translation_service
if _translation_service is None:
_translation_service = TranslationService()
return _translation_service
# Convenience functions for easy access
def t(key: str, locale: Optional[str] = None, **params) -> str:
"""
Translate a key (convenience function).
Args:
key: Translation key
locale: Target locale (optional, uses default if not provided)
**params: Parameters for parameterized messages
Returns:
Translated string
Examples:
t("common.success.created")
t("common.validation.required", field="名称")
t("workspace.member_count", count=5)
"""
service = get_translation_service()
return service.translate(key, locale, **params)
def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str:
"""
Translate an enum value (convenience function).
Args:
enum_type: Enum type name
value: Enum value
locale: Target locale
Returns:
Translated enum display name
Examples:
t_enum("workspace_role", "manager")
t_enum("invite_status", "pending", "en")
"""
service = get_translation_service()
return service.translate_enum(enum_type, value, locale)
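With the translation files added later in this commit loaded, the convenience helpers behave roughly like this (illustrative, not captured output):

```python
from app.i18n.service import t, t_enum

t("common.success.created", "zh")                           # -> "创建成功"
t("common.validation.too_long", "zh", field="名称", max=50)  # -> "名称长度不能超过50个字符"
t_enum("invite_status", "pending", "en")                    # -> "Pending"
t("common.no.such.key", "zh")                               # -> "common.no.such.key" (falls back to the key)
```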

View File

@@ -0,0 +1,26 @@
# English Translation Files
This directory contains English translation files.
## File Structure
- `common.json` - Common translations (success messages, actions, validation)
- `auth.json` - Authentication module translations
- `workspace.json` - Workspace module translations
- `tenant.json` - Tenant module translations
- `errors.json` - Error message translations
- `enums.json` - Enum value translations
## Translation File Format
All translation files use JSON format and support nested structures.
Example:
```json
{
"success": {
"created": "Created successfully",
"updated": "Updated successfully"
}
}
```
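At lookup time each nested path becomes a dot-separated key under the file's namespace; a sketch, assuming this file is loaded as the `common` namespace:

```python
from app.i18n.service import t

t("common.success.created", "en")  # -> "Created successfully"
t("common.success.updated", "en")  # -> "Updated successfully"
```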

View File

@@ -0,0 +1,55 @@
{
"login": {
"success": "Login successful",
"failed": "Login failed",
"invalid_credentials": "Invalid username or password",
"account_locked": "Account has been locked",
"account_disabled": "Account has been disabled"
},
"logout": {
"success": "Logout successful",
"failed": "Logout failed"
},
"token": {
"refresh_success": "Token refreshed successfully",
"invalid": "Invalid token",
"expired": "Token has expired",
"blacklisted": "Token has been invalidated",
"invalid_refresh_token": "Invalid refresh token",
"refresh_token_blacklisted": "Refresh token has been invalidated"
},
"registration": {
"success": "Registration successful",
"failed": "Registration failed",
"email_exists": "Email already in use",
"username_exists": "Username already taken"
},
"password": {
"reset_success": "Password reset successful",
"reset_failed": "Password reset failed",
"change_success": "Password changed successfully",
"change_failed": "Password change failed",
"incorrect": "Incorrect password",
"too_weak": "Password is too weak",
"mismatch": "Passwords do not match"
},
"invite": {
"invalid": "Invalid or expired invite code",
"email_mismatch": "Invite email does not match login email",
"accept_success": "Invite accepted successfully",
"accept_failed": "Failed to accept invite",
"password_verification_failed": "Failed to accept invite, password verification error",
"bind_workspace_success": "Workspace bound successfully",
"bind_workspace_failed": "Failed to bind workspace"
},
"user": {
"not_found": "User not found",
"already_exists": "User already exists",
"created_with_invite": "User created successfully and joined workspace"
},
"session": {
"expired": "Session expired, please login again",
"invalid": "Invalid session",
"single_session_enabled": "Single sign-on enabled, other device sessions will be logged out"
}
}

View File

@@ -0,0 +1,132 @@
{
"success": {
"created": "Created successfully",
"updated": "Updated successfully",
"deleted": "Deleted successfully",
"retrieved": "Retrieved successfully",
"saved": "Saved successfully",
"uploaded": "Uploaded successfully",
"downloaded": "Downloaded successfully",
"sent": "Sent successfully",
"completed": "Completed",
"confirmed": "Confirmed",
"cancelled": "Cancelled",
"archived": "Archived",
"restored": "Restored"
},
"actions": {
"create": "Create",
"update": "Update",
"delete": "Delete",
"view": "View",
"edit": "Edit",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"submit": "Submit",
"upload": "Upload",
"download": "Download",
"send": "Send",
"search": "Search",
"filter": "Filter",
"sort": "Sort",
"export": "Export",
"import": "Import",
"refresh": "Refresh",
"reset": "Reset",
"back": "Back",
"next": "Next",
"previous": "Previous",
"finish": "Finish",
"close": "Close",
"open": "Open",
"archive": "Archive",
"restore": "Restore",
"duplicate": "Duplicate",
"share": "Share",
"invite": "Invite",
"remove": "Remove",
"add": "Add",
"select": "Select",
"clear": "Clear"
},
"validation": {
"required": "{field} is required",
"invalid_format": "{field} format is invalid",
"too_long": "{field} cannot exceed {max} characters",
"too_short": "{field} must be at least {min} characters",
"invalid_email": "Invalid email format",
"invalid_url": "Invalid URL format",
"invalid_phone": "Invalid phone number format",
"invalid_date": "Invalid date format",
"invalid_number": "Must be a valid number",
"out_of_range": "{field} must be between {min} and {max}",
"already_exists": "{field} already exists",
"not_found": "{field} not found",
"invalid_value": "Invalid value for {field}",
"password_mismatch": "Passwords do not match",
"weak_password": "Password is too weak, please use a stronger password",
"invalid_credentials": "Invalid username or password",
"unauthorized": "Unauthorized access",
"forbidden": "Permission denied",
"expired": "{field} has expired",
"invalid_token": "Invalid token",
"file_too_large": "File size cannot exceed {max}",
"invalid_file_type": "Unsupported file type",
"duplicate": "Duplicate {field}"
},
"status": {
"active": "Active",
"inactive": "Inactive",
"pending": "Pending",
"processing": "Processing",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled",
"archived": "Archived",
"deleted": "Deleted",
"draft": "Draft",
"published": "Published",
"suspended": "Suspended",
"expired": "Expired"
},
"messages": {
"loading": "Loading...",
"saving": "Saving...",
"processing": "Processing...",
"uploading": "Uploading...",
"downloading": "Downloading...",
"no_data": "No data available",
"no_results": "No results found",
"confirm_delete": "Are you sure you want to delete? This action cannot be undone.",
"confirm_action": "Are you sure you want to perform this action?",
"operation_success": "Operation successful",
"operation_failed": "Operation failed",
"please_wait": "Please wait...",
"try_again": "Please try again",
"contact_support": "If the problem persists, please contact support"
},
"pagination": {
"page": "Page {page}",
"of": "of {total}",
"items": "{total} items",
"per_page": "{count} per page",
"showing": "Showing {from} to {to} of {total}",
"first": "First",
"last": "Last",
"next": "Next",
"previous": "Previous"
},
"time": {
"just_now": "Just now",
"minutes_ago": "{count} minutes ago",
"hours_ago": "{count} hours ago",
"days_ago": "{count} days ago",
"weeks_ago": "{count} weeks ago",
"months_ago": "{count} months ago",
"years_ago": "{count} years ago",
"today": "Today",
"yesterday": "Yesterday",
"tomorrow": "Tomorrow"
}
}
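The `{field}`, `{min}`, `{max}` style placeholders are plain `str.format` parameters; a sketch of how they interpolate (values are hypothetical):

```python
from app.i18n.service import t

t("common.validation.out_of_range", "en", field="Age", min=1, max=120)
# -> "Age must be between 1 and 120"

# "from" is a Python keyword, so pass it via dict unpacking:
t("common.pagination.showing", "en", **{"from": 1, "to": 20, "total": 87})
# -> "Showing 1 to 20 of 87"
```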

View File

@@ -0,0 +1,132 @@
{
"workspace_role": {
"owner": "Owner",
"manager": "Manager",
"member": "Member",
"guest": "Guest"
},
"workspace_status": {
"active": "Active",
"inactive": "Inactive",
"archived": "Archived",
"suspended": "Suspended",
"deleted": "Deleted"
},
"invite_status": {
"pending": "Pending",
"accepted": "Accepted",
"rejected": "Rejected",
"revoked": "Revoked",
"expired": "Expired"
},
"user_status": {
"active": "Active",
"inactive": "Inactive",
"suspended": "Suspended",
"deleted": "Deleted",
"pending": "Pending"
},
"tenant_status": {
"active": "Active",
"inactive": "Inactive",
"suspended": "Suspended",
"expired": "Expired",
"trial": "Trial"
},
"file_status": {
"uploading": "Uploading",
"processing": "Processing",
"completed": "Completed",
"failed": "Failed",
"deleted": "Deleted"
},
"task_status": {
"pending": "Pending",
"running": "Running",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled",
"paused": "Paused"
},
"priority": {
"low": "Low",
"medium": "Medium",
"high": "High",
"urgent": "Urgent"
},
"visibility": {
"public": "Public",
"private": "Private",
"internal": "Internal",
"shared": "Shared"
},
"permission": {
"read": "Read",
"write": "Write",
"delete": "Delete",
"admin": "Admin",
"owner": "Owner"
},
"notification_type": {
"info": "Info",
"warning": "Warning",
"error": "Error",
"success": "Success"
},
"language": {
"zh": "Chinese (Simplified)",
"en": "English",
"ja": "Japanese",
"ko": "Korean",
"fr": "French",
"de": "German",
"es": "Spanish"
},
"timezone": {
"utc": "UTC",
"asia_shanghai": "Asia/Shanghai",
"asia_tokyo": "Asia/Tokyo",
"america_new_york": "America/New_York",
"europe_london": "Europe/London"
},
"date_format": {
"short": "Short",
"medium": "Medium",
"long": "Long",
"full": "Full"
},
"sort_order": {
"asc": "Ascending",
"desc": "Descending"
},
"filter_operator": {
"equals": "Equals",
"not_equals": "Not Equals",
"contains": "Contains",
"not_contains": "Not Contains",
"starts_with": "Starts With",
"ends_with": "Ends With",
"greater_than": "Greater Than",
"less_than": "Less Than",
"greater_or_equal": "Greater or Equal",
"less_or_equal": "Less or Equal",
"in": "In",
"not_in": "Not In",
"is_null": "Is Null",
"is_not_null": "Is Not Null"
},
"log_level": {
"debug": "Debug",
"info": "Info",
"warning": "Warning",
"error": "Error",
"critical": "Critical"
},
"api_method": {
"get": "GET",
"post": "POST",
"put": "PUT",
"patch": "PATCH",
"delete": "DELETE"
}
}
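`translate_enum` simply prefixes the `enums` namespace, so every entry above resolves through a key of the form `enums.<type>.<value>`:

```python
from app.i18n.service import t, t_enum

t_enum("filter_operator", "starts_with", "en")  # -> "Starts With"
# Equivalent long form through the generic lookup:
t("enums.filter_operator.starts_with", "en")    # -> "Starts With"
```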

View File

@@ -0,0 +1,138 @@
{
"common": {
"internal_error": "Internal server error",
"network_error": "Network connection error",
"timeout": "Request timeout",
"service_unavailable": "Service temporarily unavailable",
"bad_request": "Bad request parameters",
"unauthorized": "Unauthorized access",
"forbidden": "Access forbidden",
"not_found": "Resource not found",
"method_not_allowed": "Method not allowed",
"conflict": "Resource conflict",
"too_many_requests": "Too many requests, please try again later",
"validation_failed": "Validation failed",
"database_error": "Database operation failed",
"file_operation_error": "File operation failed"
},
"auth": {
"invalid_credentials": "Invalid username or password",
"token_expired": "Session expired, please login again",
"token_invalid": "Invalid authentication token",
"token_missing": "Authentication token missing",
"unauthorized": "Unauthorized access",
"forbidden": "Permission denied",
"account_locked": "Account has been locked",
"account_disabled": "Account has been disabled",
"account_not_verified": "Account not verified",
"password_incorrect": "Incorrect password",
"password_too_weak": "Password is too weak",
"password_expired": "Password expired, please change it",
"email_not_verified": "Email not verified",
"phone_not_verified": "Phone number not verified",
"verification_code_invalid": "Invalid verification code",
"verification_code_expired": "Verification code expired",
"login_failed": "Login failed",
"logout_failed": "Logout failed",
"session_expired": "Session expired",
"already_logged_in": "Already logged in",
"not_logged_in": "Not logged in"
},
"user": {
"not_found": "User not found",
"already_exists": "User already exists",
"email_already_exists": "Email already in use",
"phone_already_exists": "Phone number already in use",
"username_already_exists": "Username already taken",
"invalid_email": "Invalid email format",
"invalid_phone": "Invalid phone number format",
"invalid_username": "Invalid username format",
"create_failed": "Failed to create user",
"update_failed": "Failed to update user",
"delete_failed": "Failed to delete user",
"cannot_delete_self": "Cannot delete yourself",
"cannot_update_self_role": "Cannot update your own role",
"profile_update_failed": "Failed to update profile",
"avatar_upload_failed": "Failed to upload avatar",
"password_change_failed": "Failed to change password",
"old_password_incorrect": "Old password is incorrect"
},
"workspace": {
"not_found": "Workspace not found",
"already_exists": "Workspace already exists",
"name_required": "Workspace name is required",
"name_too_long": "Workspace name is too long",
"create_failed": "Failed to create workspace",
"update_failed": "Failed to update workspace",
"delete_failed": "Failed to delete workspace",
"permission_denied": "Permission denied to access this workspace",
"not_member": "Not a workspace member",
"already_member": "Already a workspace member",
"member_limit_reached": "Member limit reached",
"cannot_leave_last_manager": "Cannot leave, you are the last manager",
"cannot_remove_last_manager": "Cannot remove the last manager",
"cannot_remove_self": "Cannot remove yourself",
"invite_not_found": "Invite not found",
"invite_expired": "Invite has expired",
"invite_already_accepted": "Invite already accepted",
"invite_already_revoked": "Invite already revoked",
"invite_send_failed": "Failed to send invite",
"archived": "Workspace is archived",
"suspended": "Workspace is suspended"
},
"tenant": {
"not_found": "Tenant not found",
"already_exists": "Tenant already exists",
"create_failed": "Failed to create tenant",
"update_failed": "Failed to update tenant",
"delete_failed": "Failed to delete tenant",
"suspended": "Tenant is suspended",
"expired": "Tenant has expired",
"license_invalid": "Invalid license",
"license_expired": "License has expired",
"quota_exceeded": "Quota exceeded"
},
"file": {
"not_found": "File not found",
"upload_failed": "File upload failed",
"download_failed": "File download failed",
"delete_failed": "File deletion failed",
"too_large": "File size exceeds limit",
"invalid_type": "Unsupported file type",
"invalid_format": "Invalid file format",
"corrupted": "File is corrupted",
"storage_full": "Storage is full",
"access_denied": "Access denied to this file"
},
"api": {
"rate_limit_exceeded": "API rate limit exceeded",
"quota_exceeded": "API quota exceeded",
"invalid_api_key": "Invalid API key",
"api_key_expired": "API key has expired",
"api_key_revoked": "API key has been revoked",
"endpoint_not_found": "API endpoint not found",
"method_not_allowed": "Method not allowed",
"invalid_request": "Invalid request",
"missing_parameter": "Missing required parameter: {param}",
"invalid_parameter": "Invalid parameter: {param}"
},
"database": {
"connection_failed": "Database connection failed",
"query_failed": "Database query failed",
"transaction_failed": "Database transaction failed",
"constraint_violation": "Data constraint violation",
"duplicate_key": "Duplicate data",
"foreign_key_violation": "Foreign key constraint violation",
"deadlock": "Database deadlock"
},
"validation": {
"invalid_input": "Invalid input data",
"missing_field": "Missing required field: {field}",
"invalid_field": "Invalid field: {field}",
"field_too_long": "Field too long: {field}",
"field_too_short": "Field too short: {field}",
"invalid_format": "Invalid format: {field}",
"invalid_value": "Invalid value: {field}",
"out_of_range": "Value out of range: {field}"
}
}

View File

@@ -0,0 +1,27 @@
{
"language": {
"not_found": "Language {locale} not found",
"already_exists": "Language {locale} already exists",
"add_instructions": "Language {locale} validated successfully. Please create translation files in {dir} directory to complete the addition.",
"update_instructions": "Language {locale} update validated successfully. Please update I18N_SUPPORTED_LANGUAGES environment variable to apply configuration changes."
},
"namespace": {
"not_found": "Namespace {namespace} not found in language {locale}"
},
"translation": {
"invalid_key_format": "Invalid translation key format: {key}. Should use format: namespace.key.subkey",
"update_instructions": "Translation {locale}/{key} update validated successfully. Please modify the corresponding JSON translation file to apply changes."
},
"reload": {
"disabled": "Translation hot reload is disabled. Please enable I18N_ENABLE_HOT_RELOAD in configuration.",
"success": "Translations reloaded successfully",
"failed": "Translation reload failed: {error}"
},
"metrics": {
"reset_success": "Performance metrics reset successfully"
},
"logs": {
"export_success": "Missing translations exported to: {file}",
"clear_success": "Missing translation logs cleared successfully"
}
}

View File

@@ -0,0 +1,63 @@
{
"info": {
"get_success": "Tenant information retrieved successfully",
"get_failed": "Failed to retrieve tenant information",
"update_success": "Tenant information updated successfully",
"update_failed": "Failed to update tenant information"
},
"create": {
"success": "Tenant created successfully",
"failed": "Failed to create tenant"
},
"delete": {
"success": "Tenant deleted successfully",
"failed": "Failed to delete tenant"
},
"status": {
"activate_success": "Tenant activated successfully",
"activate_failed": "Failed to activate tenant",
"deactivate_success": "Tenant deactivated successfully",
"deactivate_failed": "Failed to deactivate tenant"
},
"language": {
"get_success": "Tenant language configuration retrieved successfully",
"get_failed": "Failed to retrieve tenant language configuration",
"update_success": "Tenant language configuration updated successfully",
"update_failed": "Failed to update tenant language configuration",
"invalid_language": "Unsupported language code",
"default_not_in_supported": "Default language must be in the supported languages list"
},
"list": {
"get_success": "Tenant list retrieved successfully",
"get_failed": "Failed to retrieve tenant list"
},
"users": {
"list_success": "Tenant user list retrieved successfully",
"list_failed": "Failed to retrieve tenant user list",
"assign_success": "User assigned to tenant successfully",
"assign_failed": "Failed to assign user to tenant",
"remove_success": "User removed from tenant successfully",
"remove_failed": "Failed to remove user from tenant"
},
"statistics": {
"get_success": "Tenant statistics retrieved successfully",
"get_failed": "Failed to retrieve tenant statistics"
},
"validation": {
"name_required": "Tenant name is required",
"name_invalid": "Invalid tenant name format",
"name_too_long": "Tenant name cannot exceed {max} characters",
"description_too_long": "Tenant description cannot exceed {max} characters",
"language_code_invalid": "Invalid language code format",
"supported_languages_empty": "Supported languages list cannot be empty"
},
"errors": {
"not_found": "Tenant not found",
"already_exists": "Tenant name already exists",
"permission_denied": "Permission denied to access this tenant",
"has_users": "Cannot delete tenant, associated users exist",
"has_workspaces": "Cannot delete tenant, associated workspaces exist",
"already_active": "Tenant is already active",
"already_inactive": "Tenant is already inactive"
}
}

View File

@@ -0,0 +1,72 @@
{
"info": {
"get_success": "User information retrieved successfully",
"get_failed": "Failed to retrieve user information",
"update_success": "User information updated successfully",
"update_failed": "Failed to update user information"
},
"create": {
"success": "User created successfully",
"failed": "Failed to create user",
"superuser_success": "Superuser created successfully",
"superuser_failed": "Failed to create superuser"
},
"delete": {
"success": "User deleted successfully",
"failed": "Failed to delete user",
"deactivate_success": "User deactivated successfully",
"deactivate_failed": "Failed to deactivate user"
},
"activate": {
"success": "User activated successfully",
"failed": "Failed to activate user"
},
"language": {
"get_success": "Language preference retrieved successfully",
"get_failed": "Failed to retrieve language preference",
"update_success": "Language preference updated successfully",
"update_failed": "Failed to update language preference",
"invalid_language": "Unsupported language code",
"current": "Current language preference"
},
"email": {
"change_success": "Email changed successfully",
"change_failed": "Failed to change email",
"code_sent": "Verification code has been sent to your email",
"code_send_failed": "Failed to send verification code",
"code_invalid": "Invalid or expired verification code",
"already_exists": "Email already in use"
},
"list": {
"get_success": "User list retrieved successfully",
"get_failed": "Failed to retrieve user list",
"superusers_success": "Tenant superuser list retrieved successfully",
"superusers_failed": "Failed to retrieve tenant superuser list"
},
"validation": {
"username_required": "Username is required",
"username_invalid": "Invalid username format",
"username_too_long": "Username cannot exceed {max} characters",
"email_required": "Email is required",
"email_invalid": "Invalid email format",
"password_required": "Password is required",
"password_too_short": "Password must be at least {min} characters",
"password_too_long": "Password cannot exceed {max} characters",
"old_password_required": "Old password is required",
"new_password_required": "New password is required",
"verification_code_required": "Verification code is required",
"verification_code_invalid": "Invalid verification code format"
},
"errors": {
"not_found": "User not found",
"already_exists": "User already exists",
"permission_denied": "Permission denied to access this user",
"cannot_delete_self": "Cannot delete yourself",
"cannot_deactivate_self": "Cannot deactivate yourself",
"already_deactivated": "User is already deactivated",
"already_activated": "User is already activated",
"password_verification_failed": "Password verification failed",
"old_password_incorrect": "Old password is incorrect",
"same_as_old_password": "New password cannot be the same as old password"
}
}

View File

@@ -0,0 +1,44 @@
{
"list_retrieved": "Workspace list retrieved successfully",
"created": "Workspace created successfully",
"updated": "Workspace updated successfully",
"deleted": "Workspace deleted successfully",
"switched": "Workspace switched successfully",
"not_found": "Workspace not found or access denied",
"already_exists": "Workspace already exists",
"permission_denied": "No permission to access this workspace",
"name_required": "Workspace name is required",
"invalid_name": "Invalid workspace name format",
"members": {
"list_retrieved": "Workspace members list retrieved successfully",
"role_updated": "Member role updated successfully",
"deleted": "Member deleted successfully",
"not_found": "Member not found",
"cannot_remove_self": "Cannot remove yourself",
"cannot_remove_last_manager": "Cannot remove the last manager",
"already_member": "User is already a workspace member"
},
"invites": {
"created": "Invite created successfully",
"list_retrieved": "Invite list retrieved successfully",
"validated": "Invite validated successfully",
"revoked": "Invite revoked successfully",
"accepted": "Invite accepted",
"not_found": "Invite not found",
"expired": "Invite has expired",
"already_used": "Invite has already been used",
"invalid_token": "Invalid invite token",
"email_required": "Email address is required",
"invalid_email": "Invalid email address format"
},
"storage": {
"type_retrieved": "Storage type retrieved successfully",
"type_updated": "Storage type updated successfully",
"invalid_type": "Invalid storage type"
},
"models": {
"config_retrieved": "Model configuration retrieved successfully",
"config_updated": "Model configuration updated successfully",
"invalid_config": "Invalid model configuration"
}
}

View File

@@ -0,0 +1,26 @@
# Chinese Translation Files
This directory contains the Chinese (Simplified) translation files.
## File Structure
- `common.json` - Common translations (success messages, actions, validation)
- `auth.json` - Authentication module translations
- `workspace.json` - Workspace module translations
- `tenant.json` - Tenant module translations
- `errors.json` - Error message translations
- `enums.json` - Enum value translations
## Translation File Format
All translation files use JSON format and support nested structures.
Example:
```json
{
"success": {
"created": "创建成功",
"updated": "更新成功"
}
}
```

View File

@@ -0,0 +1,55 @@
{
"login": {
"success": "登录成功",
"failed": "登录失败",
"invalid_credentials": "用户名或密码错误",
"account_locked": "账户已被锁定",
"account_disabled": "账户已被禁用"
},
"logout": {
"success": "登出成功",
"failed": "登出失败"
},
"token": {
"refresh_success": "token刷新成功",
"invalid": "无效的token",
"expired": "token已过期",
"blacklisted": "token已失效",
"invalid_refresh_token": "无效的refresh token",
"refresh_token_blacklisted": "Refresh token已失效"
},
"registration": {
"success": "注册成功",
"failed": "注册失败",
"email_exists": "邮箱已被使用",
"username_exists": "用户名已被使用"
},
"password": {
"reset_success": "密码重置成功",
"reset_failed": "密码重置失败",
"change_success": "密码修改成功",
"change_failed": "密码修改失败",
"incorrect": "密码错误",
"too_weak": "密码强度不够",
"mismatch": "两次输入的密码不一致"
},
"invite": {
"invalid": "邀请码无效或已过期",
"email_mismatch": "邀请邮箱与登录邮箱不匹配",
"accept_success": "接受邀请成功",
"accept_failed": "接受邀请失败",
"password_verification_failed": "接受邀请失败,密码验证错误",
"bind_workspace_success": "绑定工作空间成功",
"bind_workspace_failed": "绑定工作空间失败"
},
"user": {
"not_found": "用户不存在",
"already_exists": "用户已存在",
"created_with_invite": "用户创建成功并已加入工作空间"
},
"session": {
"expired": "会话已过期,请重新登录",
"invalid": "无效的会话",
"single_session_enabled": "单点登录已启用,其他设备的登录将被注销"
}
}

View File

@@ -0,0 +1,132 @@
{
"success": {
"created": "创建成功",
"updated": "更新成功",
"deleted": "删除成功",
"retrieved": "获取成功",
"saved": "保存成功",
"uploaded": "上传成功",
"downloaded": "下载成功",
"sent": "发送成功",
"completed": "完成",
"confirmed": "已确认",
"cancelled": "已取消",
"archived": "已归档",
"restored": "已恢复"
},
"actions": {
"create": "创建",
"update": "更新",
"delete": "删除",
"view": "查看",
"edit": "编辑",
"save": "保存",
"cancel": "取消",
"confirm": "确认",
"submit": "提交",
"upload": "上传",
"download": "下载",
"send": "发送",
"search": "搜索",
"filter": "筛选",
"sort": "排序",
"export": "导出",
"import": "导入",
"refresh": "刷新",
"reset": "重置",
"back": "返回",
"next": "下一步",
"previous": "上一步",
"finish": "完成",
"close": "关闭",
"open": "打开",
"archive": "归档",
"restore": "恢复",
"duplicate": "复制",
"share": "分享",
"invite": "邀请",
"remove": "移除",
"add": "添加",
"select": "选择",
"clear": "清除"
},
"validation": {
"required": "{field}不能为空",
"invalid_format": "{field}格式不正确",
"too_long": "{field}长度不能超过{max}个字符",
"too_short": "{field}长度不能少于{min}个字符",
"invalid_email": "邮箱格式不正确",
"invalid_url": "URL格式不正确",
"invalid_phone": "手机号格式不正确",
"invalid_date": "日期格式不正确",
"invalid_number": "必须是有效的数字",
"out_of_range": "{field}必须在{min}和{max}之间",
"already_exists": "{field}已存在",
"not_found": "{field}不存在",
"invalid_value": "{field}的值无效",
"password_mismatch": "两次输入的密码不一致",
"weak_password": "密码强度不够,请使用更复杂的密码",
"invalid_credentials": "用户名或密码错误",
"unauthorized": "未授权访问",
"forbidden": "没有权限执行此操作",
"expired": "{field}已过期",
"invalid_token": "无效的令牌",
"file_too_large": "文件大小不能超过{max}",
"invalid_file_type": "不支持的文件类型",
"duplicate": "重复的{field}"
},
"status": {
"active": "活跃",
"inactive": "未激活",
"pending": "待处理",
"processing": "处理中",
"completed": "已完成",
"failed": "失败",
"cancelled": "已取消",
"archived": "已归档",
"deleted": "已删除",
"draft": "草稿",
"published": "已发布",
"suspended": "已暂停",
"expired": "已过期"
},
"messages": {
"loading": "加载中...",
"saving": "保存中...",
"processing": "处理中...",
"uploading": "上传中...",
"downloading": "下载中...",
"no_data": "暂无数据",
"no_results": "没有找到结果",
"confirm_delete": "确定要删除吗?此操作不可恢复。",
"confirm_action": "确定要执行此操作吗?",
"operation_success": "操作成功",
"operation_failed": "操作失败",
"please_wait": "请稍候...",
"try_again": "请重试",
"contact_support": "如果问题持续,请联系技术支持"
},
"pagination": {
"page": "第{page}页",
"of": "共{total}页",
"items": "共{total}条",
"per_page": "每页{count}条",
"showing": "显示第{from}到第{to}条,共{total}条",
"first": "首页",
"last": "末页",
"next": "下一页",
"previous": "上一页"
},
"time": {
"just_now": "刚刚",
"minutes_ago": "{count}分钟前",
"hours_ago": "{count}小时前",
"days_ago": "{count}天前",
"weeks_ago": "{count}周前",
"months_ago": "{count}个月前",
"years_ago": "{count}年前",
"today": "今天",
"yesterday": "昨天",
"tomorrow": "明天"
}
}

View File

@@ -0,0 +1,132 @@
{
"workspace_role": {
"owner": "所有者",
"manager": "管理员",
"member": "成员",
"guest": "访客"
},
"workspace_status": {
"active": "活跃",
"inactive": "未激活",
"archived": "已归档",
"suspended": "已暂停",
"deleted": "已删除"
},
"invite_status": {
"pending": "待处理",
"accepted": "已接受",
"rejected": "已拒绝",
"revoked": "已撤销",
"expired": "已过期"
},
"user_status": {
"active": "活跃",
"inactive": "未激活",
"suspended": "已暂停",
"deleted": "已删除",
"pending": "待激活"
},
"tenant_status": {
"active": "活跃",
"inactive": "未激活",
"suspended": "已暂停",
"expired": "已过期",
"trial": "试用中"
},
"file_status": {
"uploading": "上传中",
"processing": "处理中",
"completed": "已完成",
"failed": "失败",
"deleted": "已删除"
},
"task_status": {
"pending": "待处理",
"running": "运行中",
"completed": "已完成",
"failed": "失败",
"cancelled": "已取消",
"paused": "已暂停"
},
"priority": {
"low": "低",
"medium": "中",
"high": "高",
"urgent": "紧急"
},
"visibility": {
"public": "公开",
"private": "私有",
"internal": "内部",
"shared": "共享"
},
"permission": {
"read": "读取",
"write": "写入",
"delete": "删除",
"admin": "管理",
"owner": "所有者"
},
"notification_type": {
"info": "信息",
"warning": "警告",
"error": "错误",
"success": "成功"
},
"language": {
"zh": "中文(简体)",
"en": "English",
"ja": "日本語",
"ko": "한국어",
"fr": "Français",
"de": "Deutsch",
"es": "Español"
},
"timezone": {
"utc": "UTC",
"asia_shanghai": "亚洲/上海",
"asia_tokyo": "亚洲/东京",
"america_new_york": "美洲/纽约",
"europe_london": "欧洲/伦敦"
},
"date_format": {
"short": "短日期",
"medium": "中等日期",
"long": "长日期",
"full": "完整日期"
},
"sort_order": {
"asc": "升序",
"desc": "降序"
},
"filter_operator": {
"equals": "等于",
"not_equals": "不等于",
"contains": "包含",
"not_contains": "不包含",
"starts_with": "开始于",
"ends_with": "结束于",
"greater_than": "大于",
"less_than": "小于",
"greater_or_equal": "大于等于",
"less_or_equal": "小于等于",
"in": "在列表中",
"not_in": "不在列表中",
"is_null": "为空",
"is_not_null": "不为空"
},
"log_level": {
"debug": "调试",
"info": "信息",
"warning": "警告",
"error": "错误",
"critical": "严重"
},
"api_method": {
"get": "GET",
"post": "POST",
"put": "PUT",
"patch": "PATCH",
"delete": "DELETE"
}
}

View File

@@ -0,0 +1,138 @@
{
"common": {
"internal_error": "服务器内部错误",
"network_error": "网络连接错误",
"timeout": "请求超时",
"service_unavailable": "服务暂时不可用",
"bad_request": "请求参数错误",
"unauthorized": "未授权访问",
"forbidden": "没有权限访问",
"not_found": "请求的资源不存在",
"method_not_allowed": "不支持的请求方法",
"conflict": "资源冲突",
"too_many_requests": "请求过于频繁,请稍后再试",
"validation_failed": "数据验证失败",
"database_error": "数据库操作失败",
"file_operation_error": "文件操作失败"
},
"auth": {
"invalid_credentials": "用户名或密码错误",
"token_expired": "登录已过期,请重新登录",
"token_invalid": "无效的登录令牌",
"token_missing": "缺少登录令牌",
"unauthorized": "未授权访问",
"forbidden": "没有权限执行此操作",
"account_locked": "账户已被锁定",
"account_disabled": "账户已被禁用",
"account_not_verified": "账户未验证",
"password_incorrect": "密码错误",
"password_too_weak": "密码强度不够",
"password_expired": "密码已过期,请修改密码",
"email_not_verified": "邮箱未验证",
"phone_not_verified": "手机号未验证",
"verification_code_invalid": "验证码无效",
"verification_code_expired": "验证码已过期",
"login_failed": "登录失败",
"logout_failed": "登出失败",
"session_expired": "会话已过期",
"already_logged_in": "已经登录",
"not_logged_in": "未登录"
},
"user": {
"not_found": "用户不存在",
"already_exists": "用户已存在",
"email_already_exists": "邮箱已被使用",
"phone_already_exists": "手机号已被使用",
"username_already_exists": "用户名已被使用",
"invalid_email": "邮箱格式不正确",
"invalid_phone": "手机号格式不正确",
"invalid_username": "用户名格式不正确",
"create_failed": "创建用户失败",
"update_failed": "更新用户失败",
"delete_failed": "删除用户失败",
"cannot_delete_self": "不能删除自己",
"cannot_update_self_role": "不能修改自己的角色",
"profile_update_failed": "更新个人资料失败",
"avatar_upload_failed": "上传头像失败",
"password_change_failed": "修改密码失败",
"old_password_incorrect": "原密码错误"
},
"workspace": {
"not_found": "工作空间不存在",
"already_exists": "工作空间已存在",
"name_required": "工作空间名称不能为空",
"name_too_long": "工作空间名称过长",
"create_failed": "创建工作空间失败",
"update_failed": "更新工作空间失败",
"delete_failed": "删除工作空间失败",
"permission_denied": "没有权限访问此工作空间",
"not_member": "不是工作空间成员",
"already_member": "已经是工作空间成员",
"member_limit_reached": "成员数量已达上限",
"cannot_leave_last_manager": "不能离开,您是最后一个管理员",
"cannot_remove_last_manager": "不能移除最后一个管理员",
"cannot_remove_self": "不能移除自己",
"invite_not_found": "邀请不存在",
"invite_expired": "邀请已过期",
"invite_already_accepted": "邀请已被接受",
"invite_already_revoked": "邀请已被撤销",
"invite_send_failed": "发送邀请失败",
"archived": "工作空间已归档",
"suspended": "工作空间已暂停"
},
"tenant": {
"not_found": "租户不存在",
"already_exists": "租户已存在",
"create_failed": "创建租户失败",
"update_failed": "更新租户失败",
"delete_failed": "删除租户失败",
"suspended": "租户已暂停",
"expired": "租户已过期",
"license_invalid": "许可证无效",
"license_expired": "许可证已过期",
"quota_exceeded": "配额已超限"
},
"file": {
"not_found": "文件不存在",
"upload_failed": "文件上传失败",
"download_failed": "文件下载失败",
"delete_failed": "文件删除失败",
"too_large": "文件大小超过限制",
"invalid_type": "不支持的文件类型",
"invalid_format": "文件格式不正确",
"corrupted": "文件已损坏",
"storage_full": "存储空间已满",
"access_denied": "没有权限访问此文件"
},
"api": {
"rate_limit_exceeded": "API调用频率超限",
"quota_exceeded": "API调用配额已用完",
"invalid_api_key": "无效的API密钥",
"api_key_expired": "API密钥已过期",
"api_key_revoked": "API密钥已被撤销",
"endpoint_not_found": "API端点不存在",
"method_not_allowed": "不支持的请求方法",
"invalid_request": "无效的请求",
"missing_parameter": "缺少必需参数:{param}",
"invalid_parameter": "参数无效:{param}"
},
"database": {
"connection_failed": "数据库连接失败",
"query_failed": "数据库查询失败",
"transaction_failed": "数据库事务失败",
"constraint_violation": "数据约束冲突",
"duplicate_key": "数据重复",
"foreign_key_violation": "外键约束冲突",
"deadlock": "数据库死锁"
},
"validation": {
"invalid_input": "输入数据无效",
"missing_field": "缺少必需字段:{field}",
"invalid_field": "字段无效:{field}",
"field_too_long": "字段过长:{field}",
"field_too_short": "字段过短:{field}",
"invalid_format": "格式不正确:{field}",
"invalid_value": "值无效:{field}",
"out_of_range": "值超出范围:{field}"
}
}

View File

@@ -0,0 +1,27 @@
{
"language": {
"not_found": "语言 {locale} 不存在",
"already_exists": "语言 {locale} 已存在",
"add_instructions": "语言 {locale} 验证成功。请在 {dir} 目录下创建翻译文件以完成添加。",
"update_instructions": "语言 {locale} 更新验证成功。请更新环境变量 I18N_SUPPORTED_LANGUAGES 以应用配置更改。"
},
"namespace": {
"not_found": "命名空间 {namespace} 在语言 {locale} 中不存在"
},
"translation": {
"invalid_key_format": "翻译键格式无效: {key}。应使用格式: namespace.key.subkey",
"update_instructions": "翻译 {locale}/{key} 更新验证成功。请修改对应的 JSON 翻译文件以应用更改。"
},
"reload": {
"disabled": "翻译热重载功能已禁用。请在配置中启用 I18N_ENABLE_HOT_RELOAD。",
"success": "翻译重载成功",
"failed": "翻译重载失败: {error}"
},
"metrics": {
"reset_success": "性能指标已重置"
},
"logs": {
"export_success": "缺失翻译已导出到: {file}",
"clear_success": "缺失翻译日志已清除"
}
}

View File

@@ -0,0 +1,63 @@
{
"info": {
"get_success": "租户信息获取成功",
"get_failed": "租户信息获取失败",
"update_success": "租户信息更新成功",
"update_failed": "租户信息更新失败"
},
"create": {
"success": "租户创建成功",
"failed": "租户创建失败"
},
"delete": {
"success": "租户删除成功",
"failed": "租户删除失败"
},
"status": {
"activate_success": "租户启用成功",
"activate_failed": "租户启用失败",
"deactivate_success": "租户禁用成功",
"deactivate_failed": "租户禁用失败"
},
"language": {
"get_success": "租户语言配置获取成功",
"get_failed": "租户语言配置获取失败",
"update_success": "租户语言配置更新成功",
"update_failed": "租户语言配置更新失败",
"invalid_language": "不支持的语言代码",
"default_not_in_supported": "默认语言必须在支持的语言列表中"
},
"list": {
"get_success": "租户列表获取成功",
"get_failed": "租户列表获取失败"
},
"users": {
"list_success": "租户用户列表获取成功",
"list_failed": "租户用户列表获取失败",
"assign_success": "用户分配到租户成功",
"assign_failed": "用户分配到租户失败",
"remove_success": "用户从租户移除成功",
"remove_failed": "用户从租户移除失败"
},
"statistics": {
"get_success": "租户统计信息获取成功",
"get_failed": "租户统计信息获取失败"
},
"validation": {
"name_required": "租户名称不能为空",
"name_invalid": "租户名称格式不正确",
"name_too_long": "租户名称长度不能超过{max}个字符",
"description_too_long": "租户描述长度不能超过{max}个字符",
"language_code_invalid": "语言代码格式不正确",
"supported_languages_empty": "支持的语言列表不能为空"
},
"errors": {
"not_found": "租户不存在",
"already_exists": "租户名称已存在",
"permission_denied": "没有权限访问此租户",
"has_users": "无法删除租户,存在关联的用户",
"has_workspaces": "无法删除租户,存在关联的工作空间",
"already_active": "租户已处于激活状态",
"already_inactive": "租户已处于禁用状态"
}
}

View File

@@ -0,0 +1,72 @@
{
"info": {
"get_success": "用户信息获取成功",
"get_failed": "用户信息获取失败",
"update_success": "用户信息更新成功",
"update_failed": "用户信息更新失败"
},
"create": {
"success": "用户创建成功",
"failed": "用户创建失败",
"superuser_success": "超级管理员创建成功",
"superuser_failed": "超级管理员创建失败"
},
"delete": {
"success": "用户删除成功",
"failed": "用户删除失败",
"deactivate_success": "用户停用成功",
"deactivate_failed": "用户停用失败"
},
"activate": {
"success": "用户激活成功",
"failed": "用户激活失败"
},
"language": {
"get_success": "语言偏好获取成功",
"get_failed": "语言偏好获取失败",
"update_success": "语言偏好更新成功",
"update_failed": "语言偏好更新失败",
"invalid_language": "不支持的语言代码",
"current": "当前语言偏好"
},
"email": {
"change_success": "邮箱修改成功",
"change_failed": "邮箱修改失败",
"code_sent": "验证码已发送到您的邮箱,请查收",
"code_send_failed": "验证码发送失败",
"code_invalid": "验证码无效或已过期",
"already_exists": "该邮箱已被使用"
},
"list": {
"get_success": "用户列表获取成功",
"get_failed": "用户列表获取失败",
"superusers_success": "租户超管列表获取成功",
"superusers_failed": "租户超管列表获取失败"
},
"validation": {
"username_required": "用户名不能为空",
"username_invalid": "用户名格式不正确",
"username_too_long": "用户名长度不能超过{max}个字符",
"email_required": "邮箱不能为空",
"email_invalid": "邮箱格式不正确",
"password_required": "密码不能为空",
"password_too_short": "密码长度不能少于{min}个字符",
"password_too_long": "密码长度不能超过{max}个字符",
"old_password_required": "旧密码不能为空",
"new_password_required": "新密码不能为空",
"verification_code_required": "验证码不能为空",
"verification_code_invalid": "验证码格式不正确"
},
"errors": {
"not_found": "用户不存在",
"already_exists": "用户已存在",
"permission_denied": "没有权限访问此用户",
"cannot_delete_self": "不能删除自己",
"cannot_deactivate_self": "不能停用自己",
"already_deactivated": "用户已被停用",
"already_activated": "用户已处于激活状态",
"password_verification_failed": "密码验证失败",
"old_password_incorrect": "旧密码不正确",
"same_as_old_password": "新密码不能与旧密码相同"
}
}

View File

@@ -0,0 +1,44 @@
{
"list_retrieved": "工作空间列表获取成功",
"created": "工作空间创建成功",
"updated": "工作空间更新成功",
"deleted": "工作空间删除成功",
"switched": "工作空间切换成功",
"not_found": "工作空间不存在或无权访问",
"already_exists": "工作空间已存在",
"permission_denied": "没有权限访问此工作空间",
"name_required": "工作空间名称不能为空",
"invalid_name": "工作空间名称格式不正确",
"members": {
"list_retrieved": "工作空间成员列表获取成功",
"role_updated": "成员角色更新成功",
"deleted": "成员删除成功",
"not_found": "成员不存在",
"cannot_remove_self": "不能删除自己",
"cannot_remove_last_manager": "不能删除最后一个管理员",
"already_member": "用户已经是工作空间成员"
},
"invites": {
"created": "邀请创建成功",
"list_retrieved": "邀请列表获取成功",
"validated": "邀请验证成功",
"revoked": "邀请撤销成功",
"accepted": "邀请已接受",
"not_found": "邀请不存在",
"expired": "邀请已过期",
"already_used": "邀请已被使用",
"invalid_token": "无效的邀请令牌",
"email_required": "邮箱地址不能为空",
"invalid_email": "邮箱地址格式不正确"
},
"storage": {
"type_retrieved": "存储类型获取成功",
"type_updated": "存储类型更新成功",
"invalid_type": "无效的存储类型"
},
"models": {
"config_retrieved": "模型配置获取成功",
"config_updated": "模型配置更新成功",
"invalid_config": "无效的模型配置"
}
}

View File

@@ -92,6 +92,10 @@ app.add_middleware(
allow_headers=["*"],
)
# Add i18n language detection middleware
from app.i18n.middleware import LanguageMiddleware
app.add_middleware(LanguageMiddleware)
logger.info("FastAPI应用程序启动")
@@ -129,6 +133,11 @@ from app.core.exceptions import (
from app.core.sensitive_filter import SensitiveDataFilter
import traceback
# Import i18n exception support
from app.i18n.exceptions import I18nException
from app.i18n.service import get_translation_service
from pydantic import ValidationError as PydanticValidationError
# Handle validation exceptions
@app.exception_handler(ValidationException)
@@ -156,6 +165,131 @@ async def validation_exception_handler(request: Request, exc: ValidationExceptio
)
# Handle i18n (internationalization) exceptions
@app.exception_handler(I18nException)
async def i18n_exception_handler(request: Request, exc: I18nException):
"""
Handle internationalization exceptions.
I18nException has already translated the error message, so it can be returned directly.
"""
# Get the current language
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
# Get the exception detail (already contains the translated message)
detail = exc.detail
# Filter sensitive information
if isinstance(detail, dict):
filtered_message = SensitiveDataFilter.filter_string(detail.get("message", ""))
filtered_detail = {
**detail,
"message": filtered_message
}
else:
filtered_detail = SensitiveDataFilter.filter_string(str(detail))
logger.warning(
f"I18n exception: {exc.error_key}",
extra={
"path": request.url.path,
"method": request.method,
"error_code": exc.error_code,
"error_key": exc.error_key,
"language": language,
"status_code": exc.status_code,
"params": exc.params
}
)
return JSONResponse(
status_code=exc.status_code,
content={
"success": False,
**filtered_detail
},
headers=exc.headers
)
# Handle Pydantic validation errors (with i18n support)
@app.exception_handler(PydanticValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
"""
Handle Pydantic validation errors with i18n support.
"""
# Get the current language
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
# Get the translation service
translation_service = get_translation_service()
# Translate the validation error messages
errors = []
for error in exc.errors():
field = ".".join(str(loc) for loc in error["loc"])
error_type = error["type"]
# Try to translate the error message
if error_type == "value_error.missing":
message = translation_service.translate(
"errors.validation.missing_field",
language,
field=field
)
elif error_type == "value_error.any_str.max_length":
message = translation_service.translate(
"errors.validation.field_too_long",
language,
field=field
)
elif error_type == "value_error.any_str.min_length":
message = translation_service.translate(
"errors.validation.field_too_short",
language,
field=field
)
else:
# Fall back to the generic validation error message
message = translation_service.translate(
"errors.validation.invalid_field",
language,
field=field
)
errors.append({
"field": field,
"message": message,
"type": error_type
})
# Translate the main error message
main_message = translation_service.translate(
"errors.common.validation_failed",
language
)
logger.warning(
f"Pydantic validation error: {len(errors)} errors",
extra={
"path": request.url.path,
"method": request.method,
"language": language,
"errors": errors
}
)
return JSONResponse(
status_code=422,
content={
"success": False,
"error_code": "VALIDATION_FAILED",
"message": main_message,
"errors": errors
}
)
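Under these mappings, a missing required field on a zh-locale request would produce a 422 body shaped roughly like this (illustrative values, and assuming the error type string matches the `value_error.missing` branch above):

```json
{
  "success": false,
  "error_code": "VALIDATION_FAILED",
  "message": "数据验证失败",
  "errors": [
    {
      "field": "name",
      "message": "缺少必需字段:name",
      "type": "value_error.missing"
    }
  ]
}
```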
# Handle resource-not-found exceptions
@app.exception_handler(ResourceNotFoundException)
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
@@ -354,31 +488,66 @@ async def business_exception_handler(request: Request, exc: BusinessException):
)
# Unified exception handling: convert HTTPException into the unified response structure
# Unified exception handling: convert HTTPException into the unified response structure (with i18n support)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""Handle HTTP exceptions"""
# Filter sensitive information
filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))
"""Handle HTTP exceptions with i18n support"""
# Get the current language
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
# Get the translation service
translation_service = get_translation_service()
# Try to translate standard HTTP errors
error_key_map = {
400: "errors.common.bad_request",
401: "errors.common.unauthorized",
403: "errors.common.forbidden",
404: "errors.common.not_found",
405: "errors.common.method_not_allowed",
409: "errors.common.conflict",
422: "errors.common.validation_failed",
429: "errors.common.too_many_requests",
500: "errors.common.internal_error",
503: "errors.common.service_unavailable",
}
# Use the translation when a matching key exists
if exc.status_code in error_key_map:
translated_message = translation_service.translate(
error_key_map[exc.status_code],
language
)
else:
# Otherwise filter the original message
translated_message = SensitiveDataFilter.filter_string(str(exc.detail))
logger.warning(
f"HTTP exception: {filtered_detail}",
f"HTTP exception: {translated_message}",
extra={
"path": request.url.path,
"method": request.method,
"status_code": exc.status_code
"status_code": exc.status_code,
"language": language
}
)
return JSONResponse(
status_code=exc.status_code,
content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail)
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
)
# Catch unhandled exceptions and return the unified error structure
# Catch unhandled exceptions and return the unified error structure (with i18n support)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
"""Handle uncaught exceptions"""
"""Handle uncaught exceptions with i18n support"""
# Get the current language
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
# Get the translation service
translation_service = get_translation_service()
# Log the full stack trace (the log filter automatically redacts sensitive information)
logger.error(
f"Unhandled exception: {exc}",
@@ -386,6 +555,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
"path": request.url.path,
"method": request.method,
"exception_type": type(exc).__name__,
"language": language,
"traceback": traceback.format_exc()
},
exc_info=True
@@ -394,7 +564,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
# Hide detailed error information in production
environment = os.getenv("ENVIRONMENT", "development")
if environment == "production":
message = "服务器内部错误,请稍后重试"
# Use the translated generic error message
message = translation_service.translate(
"errors.common.internal_error",
language
)
else:
# Filter sensitive information in development too
message = SensitiveDataFilter.filter_string(str(exc))

View File

@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base
from app.schemas import FileType
class PerceptualType(IntEnum):
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
TEXT = 3
CONVERSATION = 4
@staticmethod
def trans_from_file_type(file_type: FileType | str):
type_map = {
FileType.IMAGE: PerceptualType.VISION,
FileType.AUDIO: PerceptualType.AUDIO,
FileType.VIDEO: PerceptualType.VISION,
FileType.DOCUMENT: PerceptualType.TEXT
}
return type_map.get(file_type, PerceptualType.TEXT)
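A small illustration of the mapping (call sites are hypothetical):

```python
PerceptualType.trans_from_file_type(FileType.AUDIO)  # -> PerceptualType.AUDIO
PerceptualType.trans_from_file_type(FileType.VIDEO)  # -> PerceptualType.VISION
PerceptualType.trans_from_file_type("unknown")       # -> PerceptualType.TEXT (default)
```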
class FileStorageService(IntEnum):
LOCAL = 1

View File

@@ -1,7 +1,7 @@
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Boolean
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy import Column, String, DateTime, Boolean, text
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from sqlalchemy.orm import relationship
from app.db import Base
@@ -20,6 +20,10 @@ class Tenants(Base):
external_id = Column(String(100), nullable=True, index=True) # External enterprise ID
external_source = Column(String(50), nullable=True) # Source system
# I18n language configuration fields
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # Tenant default language
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # Languages supported by the tenant
# Relationship to users - one tenant has many users
users = relationship("User", back_populates="tenant")
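The two new columns imply a schema migration. A minimal Alembic sketch under the same defaults (the revision wiring, index name, and migration itself are assumptions, not part of this commit):

# Hypothetical Alembic migration for the tenant language columns.
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

def upgrade():
    op.add_column("tenants", sa.Column("default_language", sa.String(10), nullable=False, server_default="zh"))
    op.add_column("tenants", sa.Column("supported_languages", postgresql.ARRAY(sa.String(10)), nullable=False, server_default=sa.text("'{zh,en}'")))
    op.create_index("ix_tenants_default_language", "tenants", ["default_language"])

def downgrade():
    op.drop_index("ix_tenants_default_language", table_name="tenants")
    op.drop_column("tenants", "supported_languages")
    op.drop_column("tenants", "default_language")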

View File

@@ -1,6 +1,6 @@
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
@@ -22,6 +22,9 @@ class User(Base):
external_id = Column(String(100), nullable=True) # 外部用户ID
external_source = Column(String(50), nullable=True) # 来源系统
# 用户语言偏好
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True) # 当前工作空间ID可为空
# Foreign key to tenant - each user belongs to exactly one tenant

View File

@@ -2,7 +2,7 @@ import uuid
from datetime import datetime
from typing import List, Tuple, Optional
from sqlalchemy import and_, desc, select
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
raise
def get_by_url(
self,
file_url: str
) -> list[MemoryPerceptualModel]:
try:
stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
return list(self.db.execute(stmt).scalars())
except Exception as e:
db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url} - {str(e)}")
raise
def get_by_type(
self,
end_user_id: uuid.UUID,

View File

@@ -0,0 +1,73 @@
"""
I18n Management API Schemas
This module defines Pydantic schemas for i18n management APIs.
"""
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Any
# ============================================================================
# Language Management Schemas
# ============================================================================
class LanguageInfo(BaseModel):
"""Language information"""
code: str = Field(..., description="Language code (e.g., 'zh', 'en')")
name: str = Field(..., description="Language name (e.g., 'Chinese', 'English')")
native_name: str = Field(..., description="Native language name (e.g., '中文', 'English')")
is_enabled: bool = Field(..., description="Whether the language is enabled")
is_default: bool = Field(..., description="Whether this is the default language")
class LanguageListResponse(BaseModel):
"""Response for language list"""
languages: List[LanguageInfo] = Field(..., description="List of available languages")
class LanguageCreateRequest(BaseModel):
"""Request to add a new language"""
code: str = Field(..., description="Language code (e.g., 'ja', 'ko')", min_length=2, max_length=10)
name: str = Field(..., description="Language name", min_length=1, max_length=100)
native_name: str = Field(..., description="Native language name", min_length=1, max_length=100)
is_enabled: bool = Field(default=True, description="Whether to enable the language")
class LanguageUpdateRequest(BaseModel):
"""Request to update language configuration"""
is_enabled: Optional[bool] = Field(None, description="Whether the language is enabled")
is_default: Optional[bool] = Field(None, description="Whether this is the default language")
# ============================================================================
# Translation Management Schemas
# ============================================================================
class TranslationResponse(BaseModel):
"""Response for translation data"""
translations: Dict[str, Dict[str, Any]] = Field(
...,
description="Translations organized by locale and namespace"
)
class TranslationUpdateRequest(BaseModel):
"""Request to update a translation"""
value: str = Field(..., description="New translation value", min_length=1)
description: Optional[str] = Field(None, description="Optional description of the translation")
class MissingTranslationsResponse(BaseModel):
"""Response for missing translations"""
missing_translations: Dict[str, List[str]] = Field(
...,
description="Missing translation keys organized by locale"
)
class ReloadResponse(BaseModel):
"""Response for translation reload"""
success: bool = Field(..., description="Whether the reload was successful")
reloaded_locales: List[str] = Field(..., description="List of reloaded locales")
total_locales: int = Field(..., description="Total number of available locales")
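As a sanity check on the shapes above, a hedged construction of a two-language payload (the import path is an assumption):

# Hypothetical usage of the schemas above.
from app.schemas.i18n_schema import LanguageInfo, LanguageListResponse  # path assumed

payload = LanguageListResponse(languages=[
    LanguageInfo(code="zh", name="Chinese", native_name="中文", is_enabled=True, is_default=True),
    LanguageInfo(code="en", name="English", native_name="English", is_enabled=True, is_default=False),
])
print(payload.model_dump_json())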

View File

@@ -25,5 +25,6 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6
TIME_SCOPE = 5

View File

@@ -1,5 +1,4 @@
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
class Content(BaseModel):
summary: str
keywords: list[str]
topic: str
domain: str

View File

@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
is_official: Optional[bool] = Field(None, description="是否官方模型")
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
class ModelInfo(BaseModel):
"""模型信息Schema"""
model_name: str = Field(..., description="模型名称")
provider: str = Field(..., description="模型提供商")
api_key: str = Field(..., description="API密钥")
api_base: str = Field(..., description="API基础URL")
is_omni: bool = Field(default=False, description="是否为omni模型")
model_type: ModelType = Field(..., description="模型类型")
capability: List[str] = Field(default_factory=list, description="模型能力列表")

View File

@@ -11,6 +11,8 @@ class TenantBase(BaseModel):
name: str = Field(..., description="租户名称", max_length=255)
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
is_active: bool = Field(True, description="是否激活")
default_language: Optional[str] = Field('zh', description="租户默认语言", max_length=10)
supported_languages: Optional[List[str]] = Field(['zh', 'en'], description="租户支持的语言列表")
@field_validator('name')
@classmethod
@@ -18,6 +20,26 @@ class TenantBase(BaseModel):
if not v or not v.strip():
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
return v.strip()
@field_validator('default_language')
@classmethod
def validate_default_language(cls, v):
if v:
# Validate language code format (2-letter code, optionally with region)
import re
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
return v
@field_validator('supported_languages')
@classmethod
def validate_supported_languages(cls, v):
if v:
import re
for lang in v:
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
return v
class TenantCreate(TenantBase):
@@ -30,6 +52,8 @@ class TenantUpdate(BaseModel):
name: Optional[str] = Field(None, description="租户名称", max_length=255)
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
is_active: Optional[bool] = Field(None, description="是否激活")
default_language: Optional[str] = Field(None, description="租户默认语言", max_length=10)
supported_languages: Optional[List[str]] = Field(None, description="租户支持的语言列表")
@field_validator('name')
@classmethod
@@ -37,6 +61,25 @@ class TenantUpdate(BaseModel):
if v is not None and (not v or not v.strip()):
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
return v.strip() if v else v
@field_validator('default_language')
@classmethod
def validate_default_language(cls, v):
if v:
import re
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
return v
@field_validator('supported_languages')
@classmethod
def validate_supported_languages(cls, v):
if v:
import re
for lang in v:
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
return v
class Tenant(TenantBase):
@@ -62,4 +105,29 @@ class TenantList(BaseModel):
total: int
page: int
size: int
pages: int
class TenantLanguageConfig(BaseModel):
"""租户语言配置Schema"""
default_language: str = Field(..., description="租户默认语言", max_length=10)
supported_languages: List[str] = Field(..., description="租户支持的语言列表")
@field_validator('default_language')
@classmethod
def validate_default_language(cls, v):
import re
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v):
raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
return v
@field_validator('supported_languages')
@classmethod
def validate_supported_languages(cls, v):
if not v:
raise ValidationException('支持的语言列表不能为空', code=BizCode.VALIDATION_FAILED)
import re
for lang in v:
if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang):
raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
return v
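The language-code pattern used throughout these validators accepts a lowercase ISO-639-1 code with an optional uppercase region; a small standalone sketch of what passes and what is rejected:

# Behavior of the validators' regex, shown standalone.
import re

pattern = re.compile(r'^[a-z]{2}(-[A-Z]{2})?$')
assert pattern.match("zh") and pattern.match("en-US")
assert not pattern.match("EN")     # wrong case
assert not pattern.match("zh_CN")  # underscore instead of hyphen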

View File

@@ -58,6 +58,16 @@ class VerifyPasswordRequest(BaseModel):
password: str = Field(..., description="密码")
class LanguagePreferenceRequest(BaseModel):
"""语言偏好设置请求"""
language: str = Field(..., min_length=2, max_length=10, description="语言代码,如 'zh', 'en'")
class LanguagePreferenceResponse(BaseModel):
"""语言偏好响应"""
language: str = Field(..., description="当前语言偏好")
class ChangePasswordResponse(BaseModel):
"""修改密码响应"""
message: str
@@ -74,6 +84,7 @@ class User(UserBase):
current_workspace_id: Optional[uuid.UUID] = None
current_workspace_name: Optional[str] = None
role: Optional[WorkspaceRole] = None
preferred_language: Optional[str] = "zh" # 用户语言偏好
# 将 datetime 转换为毫秒时间戳
@validator("created_at", pre=True)

View File

@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -126,8 +122,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态
@@ -266,8 +271,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态

View File

@@ -75,7 +75,7 @@ class AudioTranscriptionService:
try:
# 下载音频文件
async with httpx.AsyncClient(timeout=60.0) as client:
audio_response = await client.get(audio_url, follow_redirects=True)
audio_response.raise_for_status()
audio_data = audio_response.content

View File

@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_auth_logger
from app.i18n.service import t
logger = get_auth_logger()
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
user = user_repository.get_user_by_email(db, email=email)
if not user:
logger.warning(f"用户不存在: {email}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查用户状态
if not user.is_active:
logger.warning(f"用户未激活: {email}")
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
# 验证密码
if not verify_password(password, user.hashed_password):
logger.warning(f"密码错误: {email}")
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
logger.info(f"用户认证成功: {email}")
return user
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
Raises:
BusinessException: token 无效
"""
from app.i18n.service import t
try:
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
return {
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
"share_token": payload["share_token"]
}
except jwt.InvalidTokenError:
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)

View File

@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
@@ -501,9 +502,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -704,9 +714,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -841,7 +860,8 @@ class AgentRunService:
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
"is_omni": api_key.is_omni,
"capability": api_key.capability
}
async def _ensure_conversation(

View File

@@ -274,7 +274,7 @@ class MemoryAgentService:
Args:
end_user_id: Group identifier (also used as end_user_id)
message: Message to write
messages: Messages to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)

View File

@@ -1,19 +1,27 @@
import os
import uuid
from typing import Dict, Any, Optional
from urllib.parse import urlparse, unquote
import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal
)
from app.schemas.model_schema import ModelInfo
business_logger = get_business_logger()
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
"keywords": content.keywords,
"topic": content.topic,
"domain": content.domain,
"created_time": int(memory.created_time.timestamp()*1000),
"created_time": int(memory.created_time.timestamp() * 1000),
**detail
}
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
return result
except Exception as e:
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
exc_info=True)
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR)
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
for memory in memories:
meta_data = memory.meta_data or {}
content = meta_data.get("content", {})
# 安全地提取 content 字段,提供默认值
if content:
content_obj = Content(**content)
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
topic = "Unknown"
domain = "Unknown"
keywords = []
memory_item = PerceptualMemoryItem(
id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type),
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
topic=topic,
domain=domain,
keywords=keywords,
created_time=int(memory.created_time.timestamp() * 1000),
storage_service=FileStorageService(memory.storage_service),
)
memory_items.append(memory_item)
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
async def generate_perceptual_memory(
self,
end_user_id: str,
model_config: ModelInfo,
file_type: str,
file_url: str,
file_message: dict,
):
memories = self.repository.get_by_url(file_url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}")
if uuid.UUID(end_user_id) not in [memory.end_user_id for memory in memories]:  # compare as UUID; end_user_id arrives as a str
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return
llm = RedBearLLM(RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
), type=model_config.model_type)
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True)
path = urlparse(file_url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
if not file_ext:
if file_type == FileType.AUDIO:
file_ext = ".mp3"
elif file_type == FileType.VIDEO:
file_ext = ".mp4"
elif file_type == FileType.DOCUMENT:
file_ext = ".txt"
elif file_type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
"keywords": content.get("keywords", []),
"topic": content.get("topic"),
"domain": content.get("domain")
}
if file_type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene")
}
elif file_type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count"),
"title": content.get("title"),
"first_line": content.get("first_line")
}
else:
file_modalities = {
"speaker_count": content.get("speaker_count")
}
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type),
file_path=file_url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary'),
meta_data={
"content": file_content,
"modalities": file_modalities
}
)
self.db.commit()
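For orientation, the persisted meta_data for an image ends up shaped like this (values are illustrative; the keys mirror the branches above):

# Illustrative meta_data for an IMAGE file, following the branches above.
example_meta_data = {
    "content": {
        "keywords": ["sunset", "beach"],
        "topic": "Seaside scenery",
        "domain": "Photography",
    },
    "modalities": {
        "scene": ["outdoor", "coastline"],
    },
}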

View File

@@ -10,6 +10,7 @@
"""
import base64
import io
import uuid
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -23,9 +24,12 @@ from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models import ModelApiKey
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -39,6 +43,7 @@ DOC_MIME = [
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
def __init__(self, file: FileInput):
self.file = file
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
# 下载图片
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
content = response.content
self.file.set_content(content)
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
audio_data = content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
audio_data = response.content
self.file.set_content(audio_data)
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
"""
Service for handling multimodal file processing.
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
enable_audio_transcription: bool = False, is_omni: bool = False):
Attributes:
db (Session): Database session.
model_api_key (str): API key for the model provider.
provider (str): Name of the model provider.
is_omni (bool): Indicates whether the model supports full multimodal capability.
capability (list): Capability configuration of the model.
audio_api_key (str | None): API key used for audio transcription.
enable_audio_transcription (bool): Whether audio transcription is enabled.
"""
def __init__(
self,
db: Session,
api_config: ModelInfo | None = None,
audio_api_key: Optional[str] = None,
enable_audio_transcription: bool = False,
):
"""
Initialize the multimodal service.
Args:
db (Session): Database session.
api_config (ModelInfo | None): Model API configuration.
audio_api_key (str | None): API key for audio transcription.
enable_audio_transcription (bool): Enable audio transcription.
"""
self.db = db
self.api_config = api_config
if self.api_config is not None:
self.model_api_key = api_config.api_key
self.provider = api_config.provider.lower()
self.is_omni = api_config.is_omni
self.capability = api_config.capability
self.audio_api_key = audio_api_key
self.enable_audio_transcription = enable_audio_transcription
async def process_files(
self,
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
end_user_id: 用户ID
files: 文件输入列表
Returns:
@@ -319,6 +346,8 @@ class MultimodalService:
"""
if not files:
return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -333,19 +362,25 @@ class MultimodalService:
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try:
if file.type == FileType.IMAGE and "vision" in self.capability:
content = await self._process_image(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
content = await self._process_audio(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability:
content = await self._process_video(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
@@ -355,7 +390,8 @@ class MultimodalService:
"file_index": idx,
"file_type": file.type,
"error": str(e)
},
exc_info=True
)
# 继续处理其他文件,不中断整个流程
result.append({
@@ -366,6 +402,17 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
@@ -387,43 +434,6 @@ class MultimodalService:
"text": f"[图片处理失败: {str(e)}]"
}
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
Args:
url: 图片 URL
Returns:
tuple: (base64_data, media_type)
"""
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
@@ -436,7 +446,6 @@ class MultimodalService:
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
return {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
@@ -471,12 +480,12 @@ class MultimodalService:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.audio_api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
@@ -557,7 +566,7 @@ class MultimodalService:
file_content = file.get_content()
if not file_content:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file.url, follow_redirects=True)
response.raise_for_status()
file_content = response.content
file.set_content(file_content)
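Taken together, a hedged sketch of the reworked call contract from an async call site (every concrete value below is a placeholder, not a documented configuration):

# Hypothetical call site for the reworked service (placeholders throughout).
model_info = ModelInfo(
    model_name="qwen-omni",                   # placeholder
    provider="dashscope",
    api_key="sk-...",                         # placeholder
    api_base="https://dashscope.example/v1",  # placeholder
    is_omni=True,
    model_type=ModelType.LLM,
    capability=["vision", "audio", "video"],
)
service = MultimodalService(db, model_info)
processed = await service.process_files(end_user_id, files)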

View File

@@ -0,0 +1,53 @@
{% raw %}You are a professional information extraction system.
Your task is to analyze the provided document content and generate structured metadata.
Extract the following fields:
* **summary**: A concise summary of the document in 2–4 sentences.
* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
STRICT RULES:
1. Output MUST be valid JSON.
2. Do NOT output markdown.
3. Do NOT output explanations.
4. Do NOT output any text before or after the JSON.
5. The JSON MUST contain EXACTLY these four keys:
* summary
* keywords
* topic
* domain{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
{% if file_type == 'audio' %} * speaker_count {% endif %}
{% if file_type == 'document' %} * section_count
* title
* first_line
{% endif %}
{% raw %}
6. `keywords` MUST be a JSON array of strings.
7. If the document content is insufficient, infer the best possible answer based on context.
8. Ensure the JSON is syntactically correct.
{% endraw %}
9. Output in the language {{ language }}.
{% raw %}
Required JSON format:
{
"summary": "string",
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
"topic": "string",
"domain": "string",
{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
{% if file_type == 'document' %} "section_count": integer
"title": "string",
"first_line": "string"
{% endif %}
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
{% raw %}
}
Now analyze the following document and return the JSON result.{% endraw %}
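To make the contract concrete, here is how a model reply for a document file is parsed with json_repair, the way generate_perceptual_memory does (the reply text is invented for illustration):

# Parsing an (illustrative) model reply with json_repair, as the service does.
import json_repair

raw_reply = '{"summary": "A short overview of attention mechanisms.", "keywords": ["attention", "transformer", "self-attention", "NLP", "deep learning"], "topic": "Transformer attention mechanisms", "domain": "Artificial Intelligence", "section_count": 5, "title": "Attention Is All You Need", "first_line": "Abstract"}'

content = json_repair.repair_json(raw_reply, return_objects=True)
assert content["section_count"] == 5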

View File

@@ -217,4 +217,55 @@ class TenantService:
skip=skip,
limit=limit,
is_active=is_active
)
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
"""获取租户语言配置"""
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
def update_tenant_language_config(
self,
tenant_id: uuid.UUID,
default_language: str,
supported_languages: list
) -> Optional[dict]:
"""更新租户语言配置"""
# 检查租户是否存在
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
# 验证默认语言在支持的语言列表中
if default_language not in supported_languages:
raise BusinessException(
"默认语言必须在支持的语言列表中",
code=BizCode.VALIDATION_FAILED
)
try:
# 更新语言配置
tenant.default_language = default_language
tenant.supported_languages = supported_languages
self.db.commit()
self.db.refresh(tenant)
business_logger.info(
f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
f"默认语言: {default_language}, 支持语言: {supported_languages}"
)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
except Exception as e:
self.db.rollback()
business_logger.error(f"更新租户语言配置失败: {str(e)}")
raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR)

View File

@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
"""普通用户修改自己的密码"""
from app.i18n.service import t
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
# 检查权限:只能修改自己的密码
if current_user.id != user_id:
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
raise PermissionDeniedException("You can only change your own password")
raise PermissionDeniedException(t("auth.password.change_failed"))
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 验证旧密码
if not verify_password(old_password, db_user.hashed_password):
business_logger.warning(f"用户旧密码验证失败: {user_id}")
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
# 更新密码
db_user.hashed_password = get_password_hash(new_password)
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
except Exception as e:
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
Returns:
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
"""
from app.i18n.service import t
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
# 检查权限:只有超级管理员可以修改他人密码
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员可以修改他人密码"
error_message=t("auth.password.change_failed")
)
except PermissionDeniedException as e:
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
if not target_user:
business_logger.warning(f"目标用户不存在: {target_user_id}")
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查租户权限:超管只能修改同租户用户的密码
if current_user.tenant_id != target_user.tenant_id:
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
# 如果没有提供新密码,则生成随机密码
actual_password = new_password if new_password else generate_random_password()
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
except Exception as e:
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
def generate_random_password(length: int = 12) -> str:
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
#
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
# return db_user
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
"""获取用户语言偏好"""
business_logger.info(f"获取用户语言偏好: user_id={user_id}")
# 权限检查:只能获取自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能获取自己的语言偏好")
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
language = db_user.preferred_language or "zh"
business_logger.info(f"用户语言偏好: {db_user.username}, language={language}")
return language
def update_user_language_preference(
db: Session,
user_id: uuid.UUID,
language: str,
current_user: User
) -> User:
"""更新用户语言偏好"""
business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")
# 权限检查:只能修改自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能修改自己的语言偏好")
# 验证语言代码是否支持
from app.core.config import settings
if language not in settings.I18N_SUPPORTED_LANGUAGES:
raise BusinessException(
f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}",
code=BizCode.VALIDATION_FAILED
)
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 更新语言偏好
db_user.preferred_language = language
db.commit()
db.refresh(db_user)
business_logger.info(f"用户语言偏好更新成功: {db_user.username}, language={language}")
return db_user

View File

@@ -1,6 +1,5 @@
import asyncio
import json
import hashlib
import os
import re
import shutil
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
from math import ceil
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import UUID
import redis
import requests
from redis.exceptions import RedisError
# Import the unified Celery instance
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.rag.crawler.web_crawler import WebCrawler
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.feishu.models import FileInfo
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.integrations.yuque.models import YuqueDocInfo
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock
logger = get_logger(__name__)
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
# 连接 CELERY_BACKEND DB与 write_message:last_done 时间戳写入保持一致
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
_sync_redis_pool: redis.ConnectionPool | None = None
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
"""获取或创建 Redis 连接池(懒初始化)"""
global _sync_redis_pool
if _sync_redis_pool is None:
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
return None
return _sync_redis_pool
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
"""获取同步 Redis 客户端(使用连接池)
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
pool = _get_or_create_redis_pool()
if pool is None:
return None
client = redis.StrictRedis(connection_pool=pool)
# 验证连接可用性
client.ping()
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
return None
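Because both pool creation and the ping can fail, callers are expected to treat the client as optional; a minimal defensive sketch (the key literal mirrors the comment above, but the call itself is an assumption):

# Defensive use of the shared sync Redis client.
client = get_sync_redis_client()
if client is not None:
    client.set("write_message:last_done", "1700000000")  # timestamp value is a placeholder
else:
    logger.warning("Redis unavailable; skipping last_done update")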
def set_asyncio_event_loop():
"""Set the asyncio event loop for the current thread."""
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
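This helper replaces the per-task event-loop boilerplate that was previously copied into each Celery task below (the optional nest_asyncio application is dropped); every task body now reduces to the same two lines:

# Pattern each task below now follows.
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())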
@celery_app.task(name="tasks.process_item")
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
vector_size = len(vts[0])
init_graphrag(task, vector_size)
async def _run(
row: dict,
document_ids: list[str],
language: str,
parser_config: dict,
vector_service,
chat_model,
embedding_model,
callback,
with_resolution: bool = True,
with_community: bool = True
) -> dict:
await trio.sleep(5)  # Delay for 5 seconds
nonlocal progress_msg # Declare the use of an external progress_msg variable
result = await run_graphrag_for_kb(
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community,
)
)
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
with_community=with_community,
)
)
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
# Log but continue - will fail later with proper error
pass
async def _run() -> dict:
with get_db_context() as db:
service = MemoryAgentService()
return await service.read_memory(
end_user_id,
message,
history,
search_switch,
actual_config_id,
db,
storage_type,
user_rag_memory_id
)
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
user_rag_memory_id: str,
language: str = "zh") -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
Args:
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
Raises:
Exception on failure
"""
logger.info(
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
f"config_id={config_id} (type: {type(config_id).__name__}), "
f"storage_type={storage_type}, language={language}")
start_time = time.time()
# Convert config_id to UUID
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
try:
with get_db_context() as db:
actual_config_id = resolve_config_id(config_id, db)
print(100 * '-')
print(actual_config_id)
print(100 * '-')
logger.info(
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e:
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
logger.error(
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
return {
"status": "FAILURE",
"error": f"Invalid config_id format: {config_id} - {str(e)}",
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
async def _run() -> str:
with get_db_context() as db:
logger.info(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService()
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
user_rag_memory_id, language)
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
return result
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
}
def reflection_engine() -> None:
"""Empty function placeholder for timed background reflection.
Intentionally left blank; replace with real reflection logic later.
"""
import asyncio
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
asyncio.run(self_reflexion(host_id))
@celery_app.task(name="app.core.memory.agent.reflection.timer")
def reflection_timer_task() -> None:
"""Periodic Celery task that invokes reflection_engine.
Raises an exception on failure.
"""
reflection_engine()
# unused task
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
# def check_read_service_task() -> Dict[str, str]:
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
"workspace_id": workspace_id,
"elapsed_time": elapsed_time,
}
@celery_app.task(
name="app.tasks.write_all_workspaces_memory_task",
bind=True,
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
with get_db_context() as db:
try:
# 获取所有活跃的工作空间
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
).all()
if not workspaces:
api_logger.warning("没有找到活跃的工作空间")
logger.warning("没有找到活跃的工作空间")
return {
"status": "SUCCESS",
"message": "没有找到活跃的工作空间",
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"workspace_results": []
}
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
all_workspace_results = []
# 遍历每个工作空间
for workspace in workspaces:
workspace_id = workspace.id
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
try:
# 1. 查询当前workspace下的所有app仅未删除的
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"memory_increment_id": str(memory_increment.id),
"created_at": memory_increment.created_at.isoformat(),
})
api_logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0")
logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0")
continue
# 2. 查询所有app下的end_user_id去重
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
})
except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
end_user_details.append({
"end_user_id": str(end_user_id),
"total": 0,
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"created_at": memory_increment.created_at.isoformat(),
})
logger.info(
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
)
except Exception as e:
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
all_workspace_results.append({
"workspace_id": str(workspace_id),
"workspace_name": workspace.name,
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
}
except Exception as e:
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
logger.error(f"记忆增量统计任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
}
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.repositories.end_user_repository import EndUserRepository
from app.services.user_memory_service import UserMemoryService
logger.info("开始执行记忆缓存重新生成定时任务")
service = UserMemoryService()
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.models.workspace_model import Workspace
from app.services.memory_reflection_service import (
MemoryReflectionService,
WorkspaceAppService,
)
with get_db_context() as db:
try:
# 获取所有工作空间
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
# 遍历每个工作空间
for workspace in workspaces:
workspace_id = workspace.id
api_logger.info(f"开始处理工作空间反思workspace_id: {workspace_id}")
logger.info(f"开始处理工作空间反思workspace_id: {workspace_id}")
try:
reflection_service = MemoryReflectionService(db)
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
workspace_reflection_results = []
for data in result['apps_detailed_info']:
if not data['memory_configs']:
continue
releases = data['releases']
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
user['app_id']):
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_reflection_from_data(
config_data=config,
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
"reflection_results": workspace_reflection_results
})
logger.info(
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e:
db.rollback() # Rollback failed transaction to allow next query
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({
"workspace_id": str(workspace_id),
"error": str(e),
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
except Exception as e:
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
logger.error(f"工作空间反思任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
try:
# Use nest_asyncio to avoid event-loop conflicts
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# Try to reuse the existing event loop; create a new one if none exists
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.services.memory_forget_service import MemoryForgetService
api_logger = get_api_logger()
with get_db_context() as db:
try:
api_logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}")
logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}")
forget_service = MemoryForgetService()
# Run the forgetting cycle
# FIXME: MemoryForgetService
report = await forget_service.trigger_forgetting(
db=db,
end_user_id=None,  # Process all groups
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
duration = time.time() - start_time
api_logger.info(
logger.info(
f"遗忘周期定时任务完成: "
f"融合 {report['merged_count']} 对节点, "
f"失败 {report['failed_count']} 对, "
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
except Exception as e:
duration = time.time() - start_time
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
return {
"status": "FAILED",
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
finally:
loop.close()
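# These tasks now obtain their loop from set_asyncio_event_loop(), whose
# definition is not part of this excerpt. A minimal sketch, assuming the helper
# simply wraps the event-loop setup that the tasks previously inlined:
def set_asyncio_event_loop() -> asyncio.AbstractEventLoop:
    """Hypothetical helper: reuse the current event loop or install a fresh one."""
    try:
        import nest_asyncio  # optional: lets run_until_complete nest in a running loop
        nest_asyncio.apply()
    except ImportError:
        pass
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop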
# =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
start_time = time.time()
async def _run() -> Dict[str, Any]:
from sqlalchemy import func, select
from sqlalchemy import select
from app.core.logging_config import get_logger
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
total_users = 0
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
for end_user_id in refresh_iter:
logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time()
implicit_success = False
emotion_success = False
errors = []
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
failed += 1
user_elapsed = time.time() - user_start_time
# Record this user's processing result
user_result = {
"end_user_id": end_user_id,
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
}
try:
# Use nest_asyncio to avoid event-loop conflicts
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# Try to reuse the existing event loop; create a new one if none exists
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import (
ImplicitEmotionsStorageRepository,
)
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.services.implicit_memory_service import ImplicitMemoryService
logger = get_logger(__name__)
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
initialized = 0
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
Generates Chinese (zh) interest distribution data by default.
Args:
self: Bound Celery task instance
end_user_ids: 需要检查的用户ID列表
Returns:
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
from app.services.memory_agent_service import MemoryAgentService
logger = get_logger(__name__)
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
initialized = 0
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
@@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
@celery_app.task(
name="app.tasks.write_perceptual_memory",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def write_perceptual_memory(
self,
end_user_id: str,
model_api_config: dict,
file_type: str,
file_url: str,
file_message: dict
):
"""
Write perceptual memory for a user into PostgreSQL and Neo4j.
This task generates or updates the user's perceptual memory
in the backend databases. It is intended to be executed asynchronously
via Celery.
Args:
end_user_id (str): The unique identifier of the end user.
model_api_config (dict): API configuration used to construct the
ModelInfo for generating perceptual memory.
file_type (str): The type of the file.
file_url (str): The URL of the file.
file_message (dict): Metadata describing the file to be processed.
Returns:
None
"""
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
set_asyncio_event_loop()
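# Serialize work on the same file across workers; the lock key is the MD5 of the URL.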
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
model_info = ModelInfo(**model_api_config)
with get_db_context() as db:
memory_perceptual_service = MemoryPerceptualService(db)
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
end_user_id,
model_info,
file_type,
file_url,
file_message,
))
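# A minimal dispatch sketch for this task; every argument value below is an
# illustrative placeholder, and the exact model_api_config / file_message shapes
# are assumptions rather than contracts from this changeset.
if __name__ == "__main__":
    write_perceptual_memory.delay(
        end_user_id="0f8e7a2c-1d34-4b56-9a78-000000000001",
        model_api_config={"provider": "openai", "model_name": "gpt-4o"},  # assumed ModelInfo shape
        file_type="image",
        file_url="https://example.com/uploads/photo.png",
        file_message={"filename": "photo.png", "size": 10240},  # assumed shape
    )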

View File

@@ -0,0 +1,61 @@
import redis
import uuid
import time
UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
class RedisLock:
def __init__(
self,
key: str,
redis_client: redis.StrictRedis,
expire: int = 60,
retry_interval: float = 0.1,
timeout: float = 30
):
self.key = key
self.expire = expire
self.value = str(uuid.uuid4())
self._locked = False
self.retry_interval = retry_interval
self.timeout = timeout
self.redis_client = redis_client
def acquire(self) -> bool:
start = time.time()
while True:
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
if ok:
self._locked = True
return True
if time.time() - start >= self.timeout:
return False
time.sleep(self.retry_interval)
def release(self):
if not self._locked:
return
self.redis_client.eval(
UNLOCK_SCRIPT,
1,
self.key,
self.value
)
self._locked = False
def __enter__(self):
ok = self.acquire()
if not ok:
raise RuntimeError(f"Get redis lock timeout: {self.key}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
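# A minimal usage sketch; the host/port and lock key below are placeholders.
if __name__ == "__main__":
    client = redis.StrictRedis(host="localhost", port=6379, db=0)
    # acquire() polls up to `timeout` seconds; __enter__ raises RuntimeError on timeout.
    with RedisLock("demo:resource", redis_client=client, expire=60, timeout=30):
        pass  # critical section: at most one holder across workers at a time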

View File

@@ -0,0 +1,38 @@
"""202603131028
Revision ID: 01587a13522f
Revises: fb834419b18f
Create Date: 2026-03-13 10:28:43.601370
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '01587a13522f'
down_revision: Union[str, None] = 'fb834419b18f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('tenants', sa.Column('default_language', sa.String(length=10), server_default='zh', nullable=False))
op.add_column('tenants', sa.Column('supported_languages', postgresql.ARRAY(sa.String(length=10)), server_default=sa.text("'{zh,en}'"), nullable=False))
op.create_index(op.f('ix_tenants_default_language'), 'tenants', ['default_language'], unique=False)
op.add_column('users', sa.Column('preferred_language', sa.String(length=10), server_default=sa.text("'zh'"), nullable=False))
op.create_index(op.f('ix_users_preferred_language'), 'users', ['preferred_language'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_preferred_language'), table_name='users')
op.drop_column('users', 'preferred_language')
op.drop_index(op.f('ix_tenants_default_language'), table_name='tenants')
op.drop_column('tenants', 'supported_languages')
op.drop_column('tenants', 'default_language')
# ### end Alembic commands ###

View File

@@ -0,0 +1,30 @@
"""202603131452
Revision ID: ea31b4e347d8
Revises: 01587a13522f
Create Date: 2026-03-13 14:53:20.587580
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ea31b4e347d8'
down_revision: Union[str, None] = '01587a13522f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('app_shares', sa.Column('permission', sa.String(), nullable=False, comment='权限模式: readonly | editable'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('app_shares', 'permission')
# ### end Alembic commands ###

View File

@@ -36,6 +36,7 @@
"codemirror": "^6.0.2",
"copy-to-clipboard": "^3.3.3",
"crypto-js": "^4.2.0",
"d3": "^7.9.0",
"dayjs": "^1.11.18",
"echarts": "^5.6.0",
"echarts-for-react": "^3.0.2",
@@ -67,6 +68,7 @@
"@tailwindcss/vite": "^4.1.14",
"@types/codemirror": "^5.60.17",
"@types/crypto-js": "^4.2.2",
"@types/d3": "^7.4.3",
"@types/js-yaml": "^4.0.9",
"@types/node": "^24.6.0",
"@types/react": "^18.2.0",

View File

@@ -2,9 +2,10 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 14:00:06
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 10:58:41
* @Last Modified time: 2026-03-13 10:48:41
*/
import { request } from '@/utils/request'
import type { AxiosRequestConfig } from 'axios'
import type {
MemoryFormData,
} from '@/views/MemoryManagement/types'
@@ -94,8 +95,12 @@ export const updatedEndUserProfile = (values: EndUser) => {
return request.post(`/memory-storage/updated_end_user/profile`, values)
}
// User Memory - Relationship network
export const getMemorySearchEdges = (end_user_id: string) => {
return request.get(`/memory-storage/analytics/graph_data`, { end_user_id })
export const getMemorySearchEdges = (end_user_id: string, config?: AxiosRequestConfig) => {
return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }, config)
}
// User Memory - Community graph
export const getMemoryCommunityGraph = (end_user_id: string, config?: AxiosRequestConfig) => {
return request.get(`/memory-storage/analytics/community_graph`, { end_user_id }, config)
}
// User Memory - User interest distribution
export const getInterestDistributionByUser = (end_user_id: string) => {

View File

@@ -0,0 +1,67 @@
import React, { useState, useRef, useMemo, useEffect, type FC } from 'react'
import Empty from '@/components/Empty'
import { GRAPH_COLORS, initCommunityGraph } from './utils'
import { useD3Graph } from './hooks'
import type { CommunityD3Node, D3Link, CommunityGraphProps } from './types'
// ─── Component ────────────────────────────────────────────────────────────────
// Renders a D3-powered community graph with optional tooltip and legend.
const CommunityGraph: FC<CommunityGraphProps> = ({
data,
empty: emptyProp,
colors = GRAPH_COLORS,
renderTooltip,
showLegend = true,
onCommunityClick,
onNodeClick,
defaultZoom = 1,
}) => {
// Tooltip position and hovered node state
const [tooltip, setTooltip] = useState<{ x: number; y: number; node: CommunityD3Node } | null>(null)
// Keep callback refs stable to avoid re-initializing the graph on every render
const onCommunityClickRef = useRef(onCommunityClick)
const onNodeClickRef = useRef(onNodeClick)
const renderTooltipRef = useRef(renderTooltip)
useEffect(() => { onCommunityClickRef.current = onCommunityClick }, [onCommunityClick])
useEffect(() => { onNodeClickRef.current = onNodeClick }, [onNodeClick])
useEffect(() => { renderTooltipRef.current = renderTooltip }, [renderTooltip])
const graphState = useMemo(() => data, [data])
// Show empty state when explicitly flagged or when there are no nodes
const isEmpty = emptyProp ?? !data?.nodes.length
// Initialize (or re-initialize) the D3 graph whenever relevant state changes
const containerRef = useD3Graph((container) => {
if (!graphState) return
return initCommunityGraph(
container,
graphState.nodes,
graphState.links as D3Link[],
graphState.communityMap,
graphState.communityCaption,
graphState.communityNodeMap,
{ colors, showLegend, defaultZoom, setTooltip: renderTooltip ? setTooltip : () => {}, onCommunityClickRef, onNodeClickRef }
)
}, [graphState, showLegend, defaultZoom])
// Resolve tooltip content: use custom renderer if provided, otherwise fall back to DefaultTooltip
const tooltipNode = tooltip && renderTooltipRef.current
? renderTooltipRef.current(tooltip.node)
: null
if (isEmpty) return <Empty className="rb:h-full" />
return (
<div className="rb:w-full rb:h-full rb:relative">
<div ref={containerRef} className="rb:w-full rb:h-full" />
{tooltipNode ? (
<div style={{ position: 'absolute', left: tooltip!.x + 14, top: tooltip!.y - 10, pointerEvents: 'none', zIndex: 20 }}>
{tooltipNode}
</div>
) : undefined}
</div>
)
}
export default React.memo(CommunityGraph)

View File

@@ -0,0 +1,24 @@
import { useRef, useEffect } from 'react'
import * as d3 from 'd3'
/**
* Generic hook that mounts a D3 graph inside a div container.
* Clears any existing SVG before calling initFn, and runs cleanup on unmount or dep change.
*/
export function useD3Graph<T>(
initFn: (container: HTMLDivElement) => (() => void) | void,
deps: T[]
) {
const containerRef = useRef<HTMLDivElement>(null)
useEffect(() => {
const container = containerRef.current
if (!container) return
d3.select(container).selectAll('svg').remove()
const cleanup = initFn(container)
return () => {
cleanup?.()
d3.select(container).selectAll('svg').remove()
}
}, deps)
return containerRef
}

View File

@@ -0,0 +1,102 @@
import type { ReactNode, RefObject } from 'react'
import type * as d3 from 'd3'
// ─── Raw input types (mirror of API response, no external dependency) ─────────
// These interfaces map 1-to-1 with the graph API response shape.
export interface RawCommunityNode {
id: string
label: 'Community'
properties: {
name: string
summary: string
member_entity_ids: string[]
member_count: number
core_entities: string[]
community_id: string
end_user_id?: string
updated_at?: string
}
}
export interface RawEntityNode {
id: string
label: 'ExtractedEntity'
properties: {
name: string
description: string
entity_type: string
community_name?: string
[key: string]: unknown
}
}
export interface RawEdge {
id: string
source: string
target: string
}
export interface RawCommunityGraphData {
nodes: (RawCommunityNode | RawEntityNode)[]
edges: RawEdge[]
}
// ─── D3 graph types ───────────────────────────────────────────────────────────
// Runtime node shape used by D3 simulations; extends SimulationNodeDatum for x/y/vx/vy.
export interface CommunityD3Node extends d3.SimulationNodeDatum {
id: string
name: string
community: string
label: string
symbolSize: number
color: string
properties?: RawEntityNode['properties']
}
export interface D3Link extends d3.SimulationLinkDatum<CommunityD3Node> {
isCross: boolean
}
// Convex-hull shape rendered behind each community cluster.
export interface HullDatum {
id: string
path: string
color: string
labelX: number
labelY: number
dashed: boolean
caption: string
}
// Fully transformed graph data ready to be passed into initCommunityGraph.
export interface CommunityGraphData {
nodes: CommunityD3Node[]
links: Array<{ source: string; target: string; isCross: boolean }>
communityMap: Map<string, string[]>
communityCaption: Map<string, string>
communityNodeMap: Map<string, RawCommunityNode>
}
// Props accepted by the CommunityGraph React component.
export interface CommunityGraphProps {
data: CommunityGraphData | null
empty?: boolean
colors?: string[]
renderTooltip?: (node: CommunityD3Node) => ReactNode
showLegend?: boolean
onCommunityClick?: (node: RawCommunityNode) => void
onNodeClick?: (node: CommunityD3Node) => void
defaultZoom?: number
}
// Options forwarded from the React component into the D3 initializer.
export interface InitOptions {
colors: string[]
showLegend: boolean
defaultZoom: number
setTooltip: (s: { x: number; y: number; node: CommunityD3Node } | null) => void
onCommunityClickRef: RefObject<((node: RawCommunityNode) => void) | undefined>
onNodeClickRef: RefObject<((node: CommunityD3Node) => void) | undefined>
}

View File

@@ -0,0 +1,547 @@
import * as d3 from 'd3'
import type { CommunityD3Node, D3Link, HullDatum, CommunityGraphData, RawCommunityGraphData, RawCommunityNode, RawEntityNode, InitOptions } from './types'
// ─── Colors ───────────────────────────────────────────────────────────────────
export const GRAPH_COLORS = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
export const colorAt = (i: number) => GRAPH_COLORS[i % GRAPH_COLORS.length]
export function connectionToRadius(connections: number): number {
if (connections <= 1) return 5
if (connections <= 10) return 8
if (connections <= 15) return 11
if (connections <= 20) return 16
return 22
}
// ─── Arrow markers ────────────────────────────────────────────────────────────
export function addArrowMarkers(
defs: d3.Selection<SVGDefsElement, unknown, null, undefined>,
markers: { id: string; color: string }[]
) {
markers.forEach(({ id, color }) => {
defs.append('marker')
.attr('id', id)
.attr('viewBox', '0 -4 8 8')
.attr('refX', 8).attr('refY', 0)
.attr('markerWidth', 6).attr('markerHeight', 6)
.attr('orient', 'auto')
.append('path').attr('d', 'M0,-4L8,0L0,4').attr('fill', color)
})
}
// ─── Zoom ─────────────────────────────────────────────────────────────────────
export function addZoom(
svg: d3.Selection<SVGSVGElement, unknown, null, undefined>,
g: d3.Selection<SVGGElement, unknown, null, undefined>
) {
svg.call(
d3.zoom<SVGSVGElement, unknown>().scaleExtent([0.2, 4])
.on('zoom', e => g.attr('transform', e.transform))
)
}
// ─── Node drag ────────────────────────────────────────────────────────────────
export function makeNodeDrag<N extends d3.SimulationNodeDatum>(
simulation: d3.Simulation<N, d3.SimulationLinkDatum<N>>
) {
return d3.drag<SVGGElement, N>()
.on('start', (e, d) => { if (!e.active) simulation.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y })
.on('drag', (e, d) => { d.fx = e.x; d.fy = e.y })
.on('end', (e, d) => { if (!e.active) simulation.alphaTarget(0); d.fx = e.x; d.fy = e.y })
}
// ─── Cluster force ────────────────────────────────────────────────────────────
// Works for both string and number group keys.
export function makeClusterForce<N extends d3.SimulationNodeDatum & { x?: number; y?: number; vx?: number; vy?: number }>(
nodes: N[],
getGroup: (d: N) => string | number,
centers: Record<string | number, { x: number; y: number }>,
width: number,
height: number,
opts: { pullStrength?: number; minSepRatio?: number; pushStrength?: number } = {}
) {
const { pullStrength = 0.45, minSepRatio = 0.68, pushStrength = 1.0 } = opts
return (alpha: number) => {
// pre-group nodes by key to avoid repeated filter() in hot path
const groups = new Map<string, N[]>()
nodes.forEach(d => {
const k = String(getGroup(d))
if (!groups.has(k)) groups.set(k, [])
groups.get(k)!.push(d)
})
// pull toward group center
nodes.forEach(d => {
const c = centers[getGroup(d)]
if (!c) return
d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * pullStrength * alpha
d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * pullStrength * alpha
})
// live centroids
const centroids: Record<string, { x: number; y: number; n: number }> = {}
nodes.forEach(d => {
const g = String(getGroup(d))
if (!centroids[g]) centroids[g] = { x: 0, y: 0, n: 0 }
centroids[g].x += d.x ?? 0
centroids[g].y += d.y ?? 0
centroids[g].n++
})
Object.values(centroids).forEach(c => { c.x /= c.n; c.y /= c.n })
// push groups apart
const keys = Object.keys(centroids)
const minSep = Math.min(width, height) * minSepRatio
for (let i = 0; i < keys.length; i++) {
for (let j = i + 1; j < keys.length; j++) {
const ci = centroids[keys[i]], cj = centroids[keys[j]]
const dx = cj.x - ci.x, dy = cj.y - ci.y
const dist = Math.sqrt(dx * dx + dy * dy) || 1
if (dist >= minSep) continue
const push = ((minSep - dist) / dist) * pushStrength * alpha
const fx = dx * push, fy = dy * push
groups.get(keys[i])?.forEach(d => { d.vx = (d.vx ?? 0) - fx; d.vy = (d.vy ?? 0) - fy })
groups.get(keys[j])?.forEach(d => { d.vx = (d.vx ?? 0) + fx; d.vy = (d.vy ?? 0) + fy })
}
}
}
}
// ─── Group centers ────────────────────────────────────────────────────────────
export function buildGroupCenters(
keys: (string | number)[],
width: number,
height: number,
radiusRatio = 0.4
): Record<string | number, { x: number; y: number }> {
const centers: Record<string | number, { x: number; y: number }> = {}
const r = Math.min(width, height) * radiusRatio
keys.forEach((key, i) => {
const angle = (i / keys.length) * 2 * Math.PI - Math.PI / 2
centers[key] = { x: width / 2 + r * Math.cos(angle), y: height / 2 + r * Math.sin(angle) }
})
return centers
}
// ─── Community graph data transform ─────────────────────────────────────────
export function buildCommunityGraphData(raw: RawCommunityGraphData, colors: string[] = GRAPH_COLORS): CommunityGraphData | null {
const getColor = (i: number) => colors[i % colors.length]
const communityNodes = raw.nodes.filter(n => n.label === 'Community') as RawCommunityNode[]
const communityCaption = new Map<string, string>()
const communityMap = new Map<string, string[]>()
communityNodes.forEach(n => {
communityCaption.set(n.id, n.properties.name)
communityMap.set(n.id, n.properties.member_entity_ids)
})
const entityToCommunity = new Map<string, string>()
communityMap.forEach((members, commId) => members.forEach(eid => entityToCommunity.set(eid, commId)))
const commKeys = Array.from(communityMap.keys())
const commIndex = new Map(commKeys.map((k, i) => [k, i]))
const entityNodes = raw.nodes.filter(n => n.label === 'ExtractedEntity') as RawEntityNode[]
const entityNodeSet = new Set(entityNodes.map(n => n.id))
const connectionCount: Record<string, number> = {}
raw.edges.forEach(e => {
if (entityNodeSet.has(e.source)) connectionCount[e.source] = (connectionCount[e.source] || 0) + 1
if (entityNodeSet.has(e.target)) connectionCount[e.target] = (connectionCount[e.target] || 0) + 1
})
const nodes: CommunityD3Node[] = entityNodes.map(n => {
const commId = entityToCommunity.get(n.id) ?? commKeys[0]
return {
id: n.id,
name: n.properties.name,
community: commId,
label: n.label,
symbolSize: connectionToRadius(connectionCount[n.id] || 0),
color: getColor(commIndex.get(commId) ?? 0),
properties: n.properties,
}
})
if (!nodes.length) return null
const links = raw.edges
.filter(e => entityNodeSet.has(e.source) && entityNodeSet.has(e.target))
.map(e => ({
source: e.source,
target: e.target,
isCross: entityToCommunity.get(e.source) !== entityToCommunity.get(e.target),
}))
const communityNodeMap = new Map<string, RawCommunityNode>(
communityNodes.map(n => [n.id, n])
)
return { nodes, links, communityMap, communityCaption, communityNodeMap }
}
// ─── Hull helpers ─────────────────────────────────────────────────────────────
const smoothLine = d3.line<[number, number]>()
.x(d => d[0]).y(d => d[1])
.curve(d3.curveCatmullRomClosed.alpha(0.5))
function expandPoints(pts: [number, number][], pad: number): [number, number][] {
const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length
const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length
return pts.map(([x, y]) => {
const dx = x - cx, dy = y - cy
const len = Math.sqrt(dx * dx + dy * dy) || 1
return [x + (dx / len) * pad, y + (dy / len) * pad]
})
}
function toHullPoints(pts: [number, number][]): [number, number][] {
if (pts.length === 1) {
const [x, y] = pts[0]
return [[x - 1, y - 1], [x + 1, y - 1], [x, y + 1]]
}
if (pts.length === 2) {
const [[x1, y1], [x2, y2]] = pts
return [[x1, y1], [x2, y2], [(x1 + x2) / 2, (y1 + y2) / 2 - 1]]
}
return d3.polygonHull(pts) ?? pts
}
const CIRCLE_THRESHOLD = 4 // use a circle hull when a community has fewer nodes than this
const CIRCLE_SEGMENTS = 32
function circlePoints(cx: number, cy: number, r: number): [number, number][] {
return Array.from({ length: CIRCLE_SEGMENTS }, (_, i) => {
const a = (i / CIRCLE_SEGMENTS) * 2 * Math.PI
return [cx + r * Math.cos(a), cy + r * Math.sin(a)] as [number, number]
})
}
export function buildHullData(
nodes: CommunityD3Node[],
communityMap: Map<string, string[]>,
communityCaption: Map<string, string>,
colors: string[]
): HullDatum[] {
const getColor = (i: number) => colors[i % colors.length]
const byComm = new Map<string, [number, number][]>()
communityMap.forEach((_, id) => byComm.set(id, []))
nodes.forEach(d => {
if (d.x != null && d.y != null) byComm.get(d.community)?.push([d.x, d.y])
})
const hulls: HullDatum[] = []
let ci = 0
byComm.forEach((pts, id) => {
const color = getColor(ci++)
if (!pts.length) return
let pathPoints: [number, number][]
if (pts.length < CIRCLE_THRESHOLD) {
const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length
const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length
pathPoints = circlePoints(cx, cy, 60)
} else {
pathPoints = expandPoints(toHullPoints(pts), 60) as [number, number][]
}
const path = smoothLine(pathPoints)
if (!path) return
hulls.push({
id, path, color,
labelX: pathPoints.reduce((s, p) => s + p[0], 0) / pathPoints.length,
labelY: Math.min(...pathPoints.map(p => p[1])) - 10,
dashed: pts.length <= 2,
caption: communityCaption.get(id) ?? id,
})
})
return hulls
}
// ─── Hull render ──────────────────────────────────────────────────────────────
export function renderHulls(
hullG: d3.Selection<SVGGElement, unknown, null, undefined>,
hulls: HullDatum[],
hiddenCommunities: Set<string>,
nodes: CommunityD3Node[],
simulation: d3.Simulation<CommunityD3Node, D3Link>,
onCommunityClick?: (node: RawCommunityNode) => void,
communityNodeMap?: Map<string, RawCommunityNode>
) {
let dragNodes: CommunityD3Node[] = []
let dragStart = { x: 0, y: 0 }
const communityDrag = d3.drag<SVGPathElement, HullDatum>()
.on('start', (event, d) => {
if (!event.active) simulation.alphaTarget(0.3).restart()
dragNodes = nodes.filter(n => n.community === d.id)
dragStart = { x: event.x, y: event.y }
dragNodes.forEach(n => { n.fx = n.x; n.fy = n.y })
})
.on('drag', (event) => {
const dx = event.x - dragStart.x, dy = event.y - dragStart.y
dragStart = { x: event.x, y: event.y }
dragNodes.forEach(n => { n.fx = (n.fx ?? n.x ?? 0) + dx; n.fy = (n.fy ?? n.y ?? 0) + dy })
})
.on('end', (event) => { if (!event.active) simulation.alphaTarget(0) })
const pathSel = hullG.selectAll<SVGPathElement, HullDatum>('path.hull').data(hulls, d => d.id)
pathSel.enter().append('path').attr('class', 'hull').style('cursor', 'grab')
.merge(pathSel)
.call(communityDrag)
.attr('d', d => d.path)
.attr('fill', d => d.color).attr('fill-opacity', 0.08)
.attr('stroke', d => d.color).attr('stroke-opacity', 0.5).attr('stroke-width', 1.5)
.attr('stroke-dasharray', 'none')
.style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
.on('click', (event, d) => {
if ((event as MouseEvent).defaultPrevented) return
const node = communityNodeMap?.get(d.id)
if (node) onCommunityClick?.(node)
})
pathSel.exit().remove()
const labelSel = hullG.selectAll<SVGTextElement, HullDatum>('text.hull-label').data(hulls, d => d.id)
labelSel.enter().append('text').attr('class', 'hull-label')
.attr('text-anchor', 'middle').attr('font-size', '12px').attr('font-weight', '500')
.style('pointer-events', 'none')
.merge(labelSel)
.attr('x', d => d.labelX).attr('y', d => d.labelY)
.attr('fill', d => d.color)
.style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
.text(d => d.caption)
labelSel.exit().remove()
}
// ─── Community graph init ─────────────────────────────────────────────────────
export function initCommunityGraph(
container: HTMLDivElement,
nodes: CommunityD3Node[],
links: D3Link[],
communityMap: Map<string, string[]>,
communityCaption: Map<string, string>,
communityNodeMap: Map<string, RawCommunityNode>,
opts: InitOptions
) {
const { colors, showLegend, defaultZoom, setTooltip, onCommunityClickRef, onNodeClickRef } = opts
const getColor = (i: number) => colors[i % colors.length]
const width = container.clientWidth || 600
const height = container.clientHeight || 518
const svg = d3.select(container).append('svg')
.attr('width', width).attr('height', height)
.style('width', '100%').style('height', '100%')
.style('background', '#F6F8FC')
const g = svg.append('g')
const zoom = d3.zoom<SVGSVGElement, unknown>()
.scaleExtent([0.2, 4])
.on('zoom', e => g.attr('transform', e.transform))
svg.call(zoom)
if (defaultZoom !== 1) {
svg.call(zoom.transform, d3.zoomIdentity
.translate(width / 2 * (1 - defaultZoom), height / 2 * (1 - defaultZoom))
.scale(defaultZoom)
)
}
const defs = svg.append('defs')
addArrowMarkers(defs, [{ id: 'arrow', color: 'rgba(91, 97, 103, 0.7)' }])
const commKeys = Array.from(communityMap.keys())
const centers = buildGroupCenters(commKeys, width, height, 0.45)
const linkedIds = new Set(links.flatMap(l => [l.source as string, l.target as string]))
const simulation = d3.forceSimulation(nodes)
.force('link', d3.forceLink<CommunityD3Node, D3Link>(links).id(d => d.id).distance(60))
.force('charge', d3.forceManyBody().strength(-120))
.force('center', d3.forceCenter(width / 2, height / 2).strength(0.02))
.force('collision', d3.forceCollide<CommunityD3Node>(d => d.symbolSize + 16))
.force('cluster', makeClusterForce(nodes, d => d.community, centers, width, height, {
pullStrength: 0.45, minSepRatio: 0.68, pushStrength: 1.0,
}))
.force('isolatedPull', (alpha: number) => {
nodes.forEach(d => {
if (linkedIds.has(d.id)) return
const c = centers[d.community]
if (!c) return
d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * 0.4 * alpha
d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * 0.4 * alpha
})
})
const hullG = g.append('g').attr('class', 'hulls')
const hiddenCommunities = new Set<string>()
const linkSel = g.append('g').selectAll<SVGLineElement, D3Link>('line')
.data(links).enter().append('line')
.attr('stroke', '#5B6167')
.attr('stroke-opacity', d => d.isCross ? 0.3 : 0.5)
.attr('stroke-width', d => d.isCross ? 1 : 1.2)
.attr('marker-end', 'url(#arrow)')
const nodeSel = g.append('g').selectAll<SVGGElement, CommunityD3Node>('g')
.data(nodes).enter().append('g')
.call(makeNodeDrag(simulation))
nodeSel.append('circle')
.attr('r', d => d.symbolSize)
.attr('fill', d => d.color).attr('fill-opacity', 0.85)
.attr('stroke', '#fff').attr('stroke-width', 1.5)
.style('cursor', 'pointer')
.on('mouseenter', (event: MouseEvent, d: CommunityD3Node) => {
const { left, top } = container.getBoundingClientRect()
setTooltip({ x: event.clientX - left, y: event.clientY - top, node: d })
})
.on('mousemove', (event: MouseEvent) => {
const { left, top } = container.getBoundingClientRect()
const nd = d3.select<SVGCircleElement, CommunityD3Node>(event.target as SVGCircleElement).datum()
setTooltip({ x: event.clientX - left, y: event.clientY - top, node: nd })
})
.on('mouseleave', () => setTooltip(null))
.on('click', (_event: MouseEvent, d: CommunityD3Node) => onNodeClickRef.current?.(d))
nodeSel.append('text')
.text(d => d.name)
.attr('x', 0).attr('dy', d => -(d.symbolSize + 5))
.attr('text-anchor', 'middle').attr('font-size', '11px').attr('fill', '#444')
.style('pointer-events', 'none')
if (showLegend) {
renderLegend(
svg,
commKeys.map((cid, i) => ({ key: cid, label: communityCaption.get(cid) ?? cid, color: getColor(i) })),
width, height,
(key, hidden) => {
const cid = key as string
if (hidden) hiddenCommunities.add(cid)
else hiddenCommunities.delete(cid)
nodeSel.style('display', d => hiddenCommunities.has(d.community) ? 'none' : null)
linkSel.style('display', d => {
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
return hiddenCommunities.has(s.community) || hiddenCommunities.has(t.community) ? 'none' : null
})
hullG.selectAll<SVGPathElement, HullDatum>('path.hull').style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
hullG.selectAll<SVGTextElement, HullDatum>('text.hull-label').style('display', d => hiddenCommunities.has(d.id) ? 'none' : null)
}
)
}
simulation.on('tick', () => {
linkSel
.attr('x1', d => (d.source as CommunityD3Node).x ?? 0)
.attr('y1', d => (d.source as CommunityD3Node).y ?? 0)
.attr('x2', d => {
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0)
const dist = Math.sqrt(dx * dx + dy * dy) || 1
return (t.x ?? 0) - (dx / dist) * (t.symbolSize + 2)
})
.attr('y2', d => {
const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node
const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0)
const dist = Math.sqrt(dx * dx + dy * dy) || 1
return (t.y ?? 0) - (dy / dist) * (t.symbolSize + 2)
})
nodeSel.attr('transform', d => `translate(${d.x ?? 0},${d.y ?? 0})`)
renderHulls(hullG, buildHullData(nodes, communityMap, communityCaption, colors), hiddenCommunities, nodes, simulation, (n) => onCommunityClickRef.current?.(n), communityNodeMap)
})
return () => { simulation.stop(); d3.select(container).selectAll('svg').remove() }
}
// ─── Legend ───────────────────────────────────────────────────────────────────
export interface LegendItem {
key: string | number
label: string
color: string
}
const LEGEND_GAP = 12
const LEGEND_RECT_W = 20
const LEGEND_RECT_H = 10
const LEGEND_TEXT_OFFSET = 24
const LEGEND_FONT_SIZE = 11
const LEGEND_ROW_H = 24
const LEGEND_BOTTOM_PAD = 8
// Approximate text width using canvas measureText if available, else char-based estimate
function measureText(text: string, fontSize: number): number {
try {
const ctx = document.createElement('canvas').getContext('2d')
if (ctx) { ctx.font = `${fontSize}px sans-serif`; return ctx.measureText(text).width }
} catch { /* noop */ }
return text.length * fontSize * 0.6
}
export function renderLegend(
svg: d3.Selection<SVGSVGElement, unknown, null, undefined>,
items: LegendItem[],
width: number,
height: number,
onToggle: (key: string | number, hidden: boolean) => void
) {
// Compute per-item width: rect + text-offset + textW
const itemWidths = items.map(item =>
LEGEND_RECT_W + LEGEND_TEXT_OFFSET + measureText(item.label, LEGEND_FONT_SIZE)
)
// Layout items into rows
const rows: { item: LegendItem; w: number; x: number; row: number }[] = []
let rowIdx = 0, curX = 0
itemWidths.forEach((w, i) => {
const slotW = w + LEGEND_GAP
if (curX > 0 && curX + w > width - LEGEND_GAP * 2) { rowIdx++; curX = 0 }
rows.push({ item: items[i], w, x: curX, row: rowIdx })
curX += slotW
})
const totalRows = rowIdx + 1
const totalH = totalRows * LEGEND_ROW_H
const baseY = height - totalH - LEGEND_BOTTOM_PAD
// Center each row: row width = sum of item widths + gaps between adjacent items
const rowTotals: number[] = Array(totalRows).fill(0)
const rowCounts: number[] = Array(totalRows).fill(0)
rows.forEach(r => { rowCounts[r.row]++; rowTotals[r.row] += r.w })
rowTotals.forEach((_, ri) => { rowTotals[ri] += Math.max(0, rowCounts[ri] - 1) * LEGEND_GAP })
const legendG = svg.append('g')
rows.forEach(({ item, x, row }) => {
const rowOffsetX = (width - rowTotals[row]) / 2
const g = legendG.append('g')
.attr('transform', `translate(${rowOffsetX + x},${baseY + row * LEGEND_ROW_H + LEGEND_ROW_H / 2})`)
.style('cursor', 'pointer')
const rect = g.append('rect')
.attr('x', 0).attr('y', -LEGEND_RECT_H / 2)
.attr('width', LEGEND_RECT_W).attr('height', LEGEND_RECT_H).attr('rx', 2)
.attr('fill', item.color)
const text = g.append('text')
.text(item.label)
.attr('x', LEGEND_TEXT_OFFSET).attr('dy', '0.35em')
.attr('font-size', `${LEGEND_FONT_SIZE}px`).attr('fill', '#5B6167')
let hidden = false
g.on('click', () => {
hidden = !hidden
rect.attr('fill', hidden ? '#ccc' : item.color)
text.attr('fill', hidden ? '#bbb' : '#5B6167')
onToggle(item.key, hidden)
})
})
}

View File

@@ -1482,6 +1482,33 @@ export const en = {
memoryNum: 'memories',
memory_config_name: 'Memory Engine',
searchPlaceholder: 'Search memory store name',
communityNetwork: 'Community Graph',
community: 'Community',
"Person": "Person Entity Node",
"Organization": "Organization Entity Node",
"ORG": "Organization Entity Node",
"Location": "Location Entity Node",
"LOC": "Location Entity Node",
"Event": "Event Entity Node",
"Concept": "Concept Entity Node",
"Time": "Time Entity Node",
"Position": "Position Entity Node",
"WorkRole": "Work Role Entity Node",
"System": "System Entity Node",
"Policy": "Policy Entity Node",
"HistoricalPeriod": "Historical Period Entity Node",
"HistoricalState": "Historical State Entity Node",
"HistoricalEvent": "Historical Event Entity Node",
"EconomicFactor": "Economic Factor Entity Node",
"Condition": "Condition Entity Node",
"Numeric": "Numeric Entity Node",
"Work": "Work / Output",
member_count: 'Member Count',
member_count_desc: 'entities',
summary: 'Summary',
core_entities: 'Core Entities',
communityDetailEmptyDesc: 'Click on a community in the chart on the left to view details',
},
space: {
createSpace: 'Create Space',

View File

@@ -1480,6 +1480,33 @@ export const zh = {
memoryNum: '条记忆',
memory_config_name: '记忆引擎',
searchPlaceholder: '搜索记忆库名称',
communityNetwork: '社区图谱',
community: '社区',
"Person": "人物实体节点",
"Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
"Event": "事件实体节点",
"Concept": "概念实体节点",
"Time": "时间实体节点",
"Position": "职位实体节点",
"WorkRole": "职业实体节点",
"System": "系统实体节点",
"Policy": "政策实体节点",
"HistoricalPeriod": "历史时期实体节点",
"HistoricalState": "历史国家实体节点",
"HistoricalEvent": "历史事件实体节点",
"EconomicFactor": "经济因素实体节点",
"Condition": "条件实体节点",
"Numeric": "数值实体节点",
"Work": "作品/工作成果",
member_count: '成员数',
member_count_desc: '个实体',
summary: '摘要',
core_entities: '核心实体',
communityDetailEmptyDesc: '点击左侧图表中的社区查看详情',
},
space: {
createSpace: '创建空间',

View File

@@ -0,0 +1,72 @@
import React, { useState, type FC, useEffect } from 'react'
import { useParams } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import type { CommunityD3Node, CommunityGraphData, RawCommunityGraphData, RawCommunityNode } from '@/components/D3Graph/types'
import { buildCommunityGraphData } from '@/components/D3Graph/utils'
import CommunityGraph from '@/components/D3Graph/CommunityGraph'
import { getMemoryCommunityGraph } from '@/api/memory'
// ─── Tooltip ──────────────────────────────────────────────────────────────────
const NodeTooltip: FC<{ node: CommunityD3Node }> = ({ node }) => {
const { t } = useTranslation()
return (
<div style={{
background: '#fff', border: '1px solid #DFE4ED', borderRadius: 8,
boxShadow: '0 4px 16px rgba(0,0,0,0.12)', padding: '10px 14px',
minWidth: 180, maxWidth: 260, fontSize: 13,
}}>
<div style={{ fontWeight: 600, marginBottom: 6, color: '#1a1a1a', fontSize: 14 }}>
{node.properties?.name ?? node.name}
</div>
{node.properties?.description && (
<div style={{ color: '#5B6167', lineHeight: '20px', marginBottom: 4 }}>
{node.properties.description}
</div>
)}
<div style={{ color: '#5B6167', lineHeight: '22px' }}>
{t('userMemory.type')}
<span style={{ color: '#1a1a1a' }}>{t(`userMemory.${node.properties?.entity_type}`)}</span>
</div>
<div style={{ color: '#5B6167', lineHeight: '22px' }}>
{t('userMemory.community')}
<span style={{ color: node.color, fontWeight: 500 }}>{node.properties?.community_name}</span>
</div>
</div>
)
}
// ─── Component ────────────────────────────────────────────────────────────────
const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => void }> = ({ onSelectCommunity }) => {
const { id } = useParams()
const [graphData, setGraphData] = useState<CommunityGraphData | null>(null)
const [empty, setEmpty] = useState(false)
useEffect(() => {
if (!id) return
const controller = new AbortController()
setEmpty(false)
setGraphData(null)
getMemoryCommunityGraph(id, { signal: controller.signal }).then(res => {
const raw = res as RawCommunityGraphData
if (!raw.nodes?.length) { setEmpty(true); return }
const built = buildCommunityGraphData(raw)
if (!built) { setEmpty(true); return }
setGraphData(built)
}).catch((e) => { if (e?.code !== 'ERR_CANCELED') setEmpty(true) })
return () => controller.abort()
}, [id])
return (
<CommunityGraph
data={graphData}
empty={empty}
showLegend={false}
onCommunityClick={onSelectCommunity}
renderTooltip={node => <NodeTooltip node={node} />}
/>
)
}
export default React.memo(CommunityNetwork)

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 18:32:00
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 18:32:00
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-13 14:51:17
*/
/**
* Relationship Network Component
@@ -13,18 +13,20 @@
import React, { type FC, useEffect, useState, useRef, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { useParams, useNavigate } from 'react-router-dom'
import { Col, Row, Space, Button } from 'antd'
import { Col, Row, Space, Button, Tabs, Flex, Divider } from 'antd'
import dayjs from 'dayjs'
import ReactEcharts from 'echarts-for-react'
import RbCard from '@/components/RbCard/Card'
import detailEmpty from '@/assets/images/userMemory/detail_empty.png'
import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties } from '../types'
import type { RawCommunityNode } from '@/components/D3Graph/types'
import {
getMemorySearchEdges,
} from '@/api/memory'
import Empty from '@/components/Empty'
import Tag from '@/components/Tag'
import CommunityNetwork from './CommunityNetwork'
/** Node color palette */
const colors = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048']
@@ -36,16 +38,21 @@ const RelationshipNetwork:FC = () => {
const [nodes, setNodes] = useState<Node[]>([])
const [links, setLinks] = useState<Edge[]>([])
const [categories, setCategories] = useState<{ name: string }[]>([])
const [selectedNode, setSelectedNode] = useState<Node | null>(null)
const [selectedNode, setSelectedNode] = useState<Node | RawCommunityNode | null>(null)
// const [fullScreen, setFullScreen] = useState<boolean>(false)
const navigate = useNavigate()
const [activeTab, setActiveTab] = useState('relationshipNetwork')
console.log('categories', categories)
const edgeAbortRef = useRef<AbortController | null>(null)
/** Fetch relationship network data */
const getEdgeData = useCallback(() => {
if (!id) return
edgeAbortRef.current?.abort()
edgeAbortRef.current = new AbortController()
setSelectedNode(null)
getMemorySearchEdges(id).then((res) => {
getMemorySearchEdges(id, { signal: edgeAbortRef.current.signal }).then((res) => {
const { nodes, edges, statistics } = res as GraphData
const curNodes: Node[] = []
const curEdges: Edge[] = []
@@ -123,6 +130,7 @@ const RelationshipNetwork:FC = () => {
useEffect(() => {
if (!id) return
getEdgeData()
return () => { edgeAbortRef.current?.abort() }
}, [id])
useEffect(() => {
@@ -153,34 +161,36 @@ const RelationshipNetwork:FC = () => {
const params = new URLSearchParams({
nodeId: selectedNode.id,
nodeLabel: selectedNode.label,
nodeName: selectedNode.name || ''
nodeName: (selectedNode as Node).name || ''
})
navigate(`/user-memory/detail/${id}/GRAPH?${params.toString()}`)
}
const handleChangeTab = (tab: string) => {
if (tab === 'communityNetwork') {
edgeAbortRef.current?.abort()
} else {
getEdgeData()
}
setActiveTab(tab)
setSelectedNode(null)
}
return (
<Row gutter={16}>
{/* Relationship Network */}
<Col span={16}>
<RbCard
title={t('userMemory.relationshipNetwork')}
headerType="borderless"
headerClassName="rb:min-h-[46px]!"
// extra={
// <div
// onClick={handleFullScreen}
// className="rb:group rb:cursor-pointer rb:hover:text-[#212332] rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:flex rb:items-center rb:gap-1"
// >
// <div className="rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/fullScreen.svg')] rb:hover:bg-[url('@/assets/images/fullScreen_hover.svg')]"></div>
// {t('userMemory.fullScreen')}
// </div>
// }
>
<RbCard bodyClassName="rb:pt-0!">
<Tabs
items={['relationshipNetwork', 'communityNetwork'].map(key => ({ key, label: t(`userMemory.${key}`) }))}
activeKey={activeTab}
onChange={handleChangeTab}
/>
<div className="rb:h-129.5 rb:bg-[#F6F8FC] rb:border rb:border-[#DFE4ED] rb:rounded-sm">
{nodes.length === 0 ? (
<Empty className="rb:h-full" />
) : (
<ReactEcharts
{activeTab === 'communityNetwork'
? <CommunityNetwork onSelectCommunity={community => setSelectedNode(community)} />
: nodes.length === 0
? <Empty className="rb:h-full" />
: <ReactEcharts
option={{
colors: colors,
tooltip: {
@@ -253,103 +263,121 @@ const RelationshipNetwork:FC = () => {
}
}}
/>
)}
}
</div>
</RbCard>
</Col>
{/* Memory Details */}
<Col span={8}>
<RbCard
<RbCard
title={t('userMemory.memoryDetails')}
headerType="borderless"
headerClassName="rb:min-h-[46px]!"
bodyClassName='rb:p-0!'
extra={selectedNode && <Button type="text" onClick={handleViewAll}>
<div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/view.svg')] rb:hover:bg-[url('@/assets/images/userMemory/view_hover.svg')]"
></div>
{t('userMemory.completeMemory')}
</Button>}
bodyClassName="rb:p-0!"
extra={selectedNode && !(selectedNode as RawCommunityNode).properties.community_id && (
<Button type="text" onClick={handleViewAll}>
<div className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/userMemory/view.svg')] rb:hover:bg-[url('@/assets/images/userMemory/view_hover.svg')]" />
{t('userMemory.completeMemory')}
</Button>
)}
>
<div className="rb:h-133.5 rb:overflow-y-auto">
{!selectedNode
? <Empty
url={detailEmpty}
subTitle={t('userMemory.memoryDetailEmptyDesc')}
className="rb:h-full rb:mx-10 rb:text-center"
size={[197.81, 150]}
/>
: <>
{selectedNode.name && <div className="rb:bg-[#F6F8FC] rb:border-t rb:border-b rb:border-[#DFE4ED] rb:font-medium rb:py-2 rb:px-4 rb:h-10">{selectedNode.name}</div>}
<div className="rb:p-4">
<>
? <Empty url={detailEmpty} subTitle={activeTab === 'relationshipNetwork' ? t('userMemory.memoryDetailEmptyDesc') : t('userMemory.communityDetailEmptyDesc')} className="rb:h-full rb:mx-10 rb:text-center" size={[197.81, 150]} />
: (selectedNode as RawCommunityNode).properties.community_id
? <div className="rb:p-3 rb:pt-0">
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
{(selectedNode as RawCommunityNode).properties.name}
</div>
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
{(selectedNode as RawCommunityNode).properties.summary}
</div>
<Flex align="center" justify="space-between" className="rb:mt-5!">
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
</Flex>
<Divider className='rb:my-2.5!' />
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
{(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) => <li key={index}>{entity}</li>)}
</ul>
</div>
: <>
{(selectedNode as Node).name && (
<div className="rb:bg-[#F6F8FC] rb:border-t rb:border-b rb:border-[#DFE4ED] rb:font-medium rb:py-2 rb:px-4 rb:h-10">
{(selectedNode as Node).name}
</div>
)}
<div className="rb:p-4">
<div className="rb:font-medium rb:leading-5">{t('userMemory.memoryContent')}</div>
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{['Chunk', 'Dialogue', 'MemorySummary'].includes(selectedNode.label) && 'content' in selectedNode.properties
? selectedNode.properties.content
: selectedNode.label === 'ExtractedEntity' && 'description' in selectedNode.properties
? selectedNode.properties.description
: selectedNode.label === 'Statement' && 'statement' in selectedNode.properties
? selectedNode.properties.statement
: ''
}
</div>
</>
<div className="rb:font-medium rb:mb-2 rb:mt-4">
<div className="rb:font-medium rb:leading-5">{t('userMemory.created_at')}</div>
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{dayjs(selectedNode?.properties.created_at).format('YYYY-MM-DD HH:mm:ss')}
? selectedNode.properties.description
: selectedNode.label === 'Statement' && 'statement' in selectedNode.properties
? selectedNode.properties.statement
: ''}
</div>
{selectedNode?.properties.associative_memory > 0 && <div className="rb:mt-4">
<div className="rb:font-medium rb:leading-5">{t('userMemory.associative_memory')}</div>
<div className="rb:font-medium rb:mb-2 rb:mt-4">
<div className="rb:font-medium rb:leading-5">{t('userMemory.created_at')}</div>
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
<span className="rb:text-[#155EEF] rb:font-medium">{selectedNode?.properties.associative_memory}</span> {t('userMemory.unix')}{t('userMemory.associative_memory')}
{dayjs((selectedNode as Node).properties.created_at).format('YYYY-MM-DD HH:mm:ss')}
</div>
</div>}
{selectedNode.label === 'Statement' && <>
{(['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => {
const statementProps = selectedNode.properties as StatementNodeProperties;
if ((key === 'emotion_keywords' && statementProps[key]?.length > 0) || typeof statementProps[key] === 'string') {
console.log('statementProps[key]', statementProps[key])
return (
<div className="rb:mt-4" key={key}>
{t(`userMemory.Statement_${key}`)}
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{key === 'emotion_keywords'
? <Space>{statementProps.emotion_keywords.map((vo, index) => <Tag key={index}>{vo}</Tag>)}</Space>
: statementProps[key]
}
{(selectedNode as Node).properties.associative_memory > 0 && (
<div className="rb:mt-4">
<div className="rb:font-medium rb:leading-5">{t('userMemory.associative_memory')}</div>
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
<span className="rb:text-[#155EEF] rb:font-medium">{(selectedNode as Node).properties.associative_memory}</span>
{' '}{t('userMemory.unix')}{t('userMemory.associative_memory')}
</div>
</div>
)}
{selectedNode.label === 'Statement' && (
(['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => {
const p = selectedNode.properties as StatementNodeProperties
if ((key === 'emotion_keywords' && p[key]?.length > 0) || typeof p[key] === 'string') {
return (
<div className="rb:mt-4" key={key}>
{t(`userMemory.Statement_${key}`)}
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{key === 'emotion_keywords'
? <Space>{p.emotion_keywords.map((v, i) => <Tag key={i}>{v}</Tag>)}</Space>
: p[key]}
</div>
</div>
</div>
)
}
return null
})}
</>}
{selectedNode.label === 'ExtractedEntity' && <>
{(['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => {
const entityProps = selectedNode.properties as ExtractedEntityNodeProperties;
if (entityProps[key]) {
return (
<div className="rb:mt-4" key={key}>
{t(`userMemory.ExtractedEntity_${key}`)}
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{Array.isArray(entityProps[key]) && entityProps[key].length > 0
? entityProps[key].map((vo, index) => <div key={index}>- {vo}</div>)
: entityProps[key]
}
)
}
return null
})
)}
{selectedNode.label === 'ExtractedEntity' && (
(['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => {
const p = selectedNode.properties as ExtractedEntityNodeProperties
if (p[key]) {
return (
<div className="rb:mt-4" key={key}>
{t(`userMemory.ExtractedEntity_${key}`)}
<div className="rb:text-[#5B6167] rb:font-regular rb:leading-5 rb:mt-1 rb:pb-4 rb:border-b rb:border-[#DFE4ED]">
{Array.isArray(p[key]) && p[key].length > 0
? p[key].map((v, i) => <div key={i}>- {v}</div>)
: p[key]}
</div>
</div>
</div>
)
}
return null
})}
</>}
)
}
return null
})
)}
</div>
</div>
</div>
</>
</>
}
</div>
</RbCard>

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 17:57:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:57:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-13 11:49:52
*/
/**
* User Memory Detail Types
@@ -90,6 +90,7 @@ export interface ExtractedEntityNodeProperties {
connect_strngth: string;
importance_score: number;
associative_memory: number;
community_name?: string;
}
/**
* Memory summary node
@@ -246,4 +247,53 @@ export interface ForgetData {
*/
export interface GraphDetailRef {
handleOpen: (vo: Node) => void
}
}
// Community
export type CommunityNodeType = 'Community' | 'ExtractedEntity';
export type CommunityEdgeType = 'BELONGS_TO_COMMUNITY' | 'EXTRACTED_RELATIONSHIP';
export type CommunityEntityType = "Person" | "Organization" | "ORG" | "Location" | "LOC" | "Event" | "Concept" | "Time" | "Position" | "WorkRole" | "System" | "Policy" | "HistoricalPeriod" | "HistoricalState" | "HistoricalEvent" | "EconomicFactor" | "Condition" | "Numeric" | "Work";
// Community node
export interface CommunityTypeNode {
id: string;
label: 'Community';
properties: {
community_id: string;
end_user_id: string;
member_count: number;
updated_at: string;
name: string;
summary: string;
core_entities: string[];
member_entity_ids: string[];
};
}
// Core entity node
export interface ExtractedEntityTypeNode {
id: string;
label: 'ExtractedEntity';
properties: {
name: string;
end_user_id: string;
description: string;
created_at: string;
entity_type: CommunityEntityType;
community_name: string;
};
}
// Community graph edge
export interface CommunityEdge {
id: string;
target: string;
source: string;
}
export interface CommunityStatistics {
total_nodes: number;
total_edges: number;
node_types: Record<CommunityNodeType, number>;
edge_types: Record<CommunityEdgeType, number>;
}
export interface CommunityGraphData {
nodes: (CommunityTypeNode | ExtractedEntityTypeNode)[];
edges: CommunityEdge[];
statistics: CommunityStatistics;
}