diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index f758dd15..aac2aa84 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -1,10 +1,11 @@ -import os import asyncio import json import logging from typing import Dict, Any, Optional + import redis.asyncio as redis from redis.asyncio import ConnectionPool + from app.core.config import settings # 设置日志记录器 diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 21ee291d..60c22855 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -62,7 +62,7 @@ celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', - + # 时区 timezone='Asia/Shanghai', enable_utc=False, @@ -70,43 +70,44 @@ celery_app.conf.update( # 任务追踪 task_track_started=True, task_ignore_result=False, - + # 超时设置 task_time_limit=3600, # 60分钟硬超时 task_soft_time_limit=3000, # 50分钟软超时 - + # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution - + # 结果过期时间 result_expires=3600, # 结果保存1小时 - + # 任务确认设置 task_acks_late=True, task_reject_on_worker_lost=True, worker_disable_rate_limits=True, - + # FLower setting worker_send_task_events=True, task_send_sent_event=True, - + # task routing task_routes={ # Memory tasks → memory_tasks queue (threads worker) 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, - + 'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, - + # Document tasks → document_tasks queue (prefork worker) 
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'}, - + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, @@ -131,7 +132,7 @@ implicit_emotions_update_schedule = crontab( minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE, ) -#构建定时任务配置 +# 构建定时任务配置 beat_schedule_config = { "run-workspace-reflection": { "task": "app.tasks.workspace_reflection_task", diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py index 7d3ee686..4ea4fee1 100644 --- a/api/app/celery_worker.py +++ b/api/app/celery_worker.py @@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized") # 导入任务模块以注册任务 import app.tasks -__all__ = ['celery_app'] \ No newline at end of file +__all__ = ['celery_app'] diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 85550f94..585de2ed 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -16,6 +16,7 @@ from . 
import ( file_controller, file_storage_controller, home_page_controller, + i18n_controller, implicit_memory_controller, knowledge_controller, knowledgeshare_controller, @@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) +manager_router.include_router(i18n_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 63f484ca..059bec6b 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -844,6 +844,15 @@ async def draft_run_compare( raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) service._validate_app_accessible(app, workspace_id) + if payload.user_id is None: + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, + other_id=str(current_user.id), + original_user_id=str(current_user.id) # Save original user_id to other_id + ) + payload.user_id = str(new_end_user.id) + # 2. 
获取 Agent 配置 from sqlalchemy import select from app.models import AgentConfig @@ -889,6 +898,8 @@ async def draft_run_compare( "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id }) + + # 流式返回 if payload.stream: async def event_generator(): @@ -900,7 +911,7 @@ async def draft_run_compare( message=payload.message, workspace_id=workspace_id, conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), + user_id=payload.user_id, variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, @@ -931,7 +942,7 @@ async def draft_run_compare( message=payload.message, workspace_id=workspace_id, conversation_id=payload.conversation_id, - user_id=payload.user_id or str(current_user.id), + user_id=payload.user_id, variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, diff --git a/api/app/controllers/auth_controller.py b/api/app/controllers/auth_controller.py index 708cbaa2..2cc72a3b 100644 --- a/api/app/controllers/auth_controller.py +++ b/api/app/controllers/auth_controller.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta, timezone +from typing import Callable from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException from app.core.error_codes import BizCode from app.dependencies import get_current_user, oauth2_scheme from app.models.user_model import User +from app.i18n.dependencies import get_translator # 获取专用日志器 auth_logger = get_auth_logger() @@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"]) @router.post("/token", response_model=ApiResponse) async def login_for_access_token( form_data: TokenRequest, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + t: Callable = Depends(get_translator) ): """用户登录获取token""" auth_logger.info(f"用户登录请求: {form_data.email}") @@ -40,10 +43,10 @@ async def login_for_access_token( invite_info = 
workspace_service.validate_invite_token(db, form_data.invite) if not invite_info.is_valid: - raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST) + raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST) if invite_info.email != form_data.email: - raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST) + raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST) auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}") try: # 尝试认证用户 @@ -69,7 +72,7 @@ async def login_for_access_token( elif e.code == BizCode.PASSWORD_ERROR: # 用户存在但密码错误 auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}") - raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED) + raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED) else: # 其他认证失败情况,直接抛出 raise @@ -82,7 +85,7 @@ async def login_for_access_token( except BusinessException as e: # 其他认证失败情况,直接抛出 - raise BusinessException(e.message,BizCode.LOGIN_FAILED) + raise BusinessException(e.message, BizCode.LOGIN_FAILED) # 创建 tokens access_token, access_token_id = security.create_access_token(subject=user.id) @@ -110,14 +113,15 @@ async def login_for_access_token( expires_at=access_expires_at, refresh_expires_at=refresh_expires_at ), - msg="登录成功" + msg=t("auth.login.success") ) @router.post("/refresh", response_model=ApiResponse) async def refresh_token( refresh_request: RefreshTokenRequest, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + t: Callable = Depends(get_translator) ): """刷新token""" auth_logger.info("收到token刷新请求") @@ -125,18 +129,18 @@ async def refresh_token( # 验证 refresh token userId = security.verify_token(refresh_request.refresh_token, "refresh") if not userId: - raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID) + raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID) # 检查用户是否存在 user = auth_service.get_user_by_id(db, userId) if not user: 
- raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND) # 检查 refresh token 黑名单 if settings.ENABLE_SINGLE_SESSION: refresh_token_id = security.get_token_id(refresh_request.refresh_token) if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id): - raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED) + raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED) # 生成新 tokens new_access_token, new_access_token_id = security.create_access_token(subject=user.id) @@ -167,7 +171,7 @@ async def refresh_token( expires_at=access_expires_at, refresh_expires_at=refresh_expires_at ), - msg="token刷新成功" + msg=t("auth.token.refresh_success") ) @@ -175,14 +179,15 @@ async def refresh_token( async def logout( token: str = Depends(oauth2_scheme), current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + t: Callable = Depends(get_translator) ): """登出当前用户:加入token黑名单并清理会话""" auth_logger.info(f"用户 {current_user.username} 请求登出") token_id = security.get_token_id(token) if not token_id: - raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID) + raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID) # 加入黑名单 await SessionService.blacklist_token(token_id) @@ -192,5 +197,5 @@ async def logout( await SessionService.clear_user_session(current_user.username) auth_logger.info(f"用户 {current_user.username} 登出成功") - return success(msg="登出成功") + return success(msg=t("auth.logout.success")) diff --git a/api/app/controllers/i18n_controller.py b/api/app/controllers/i18n_controller.py new file mode 100644 index 00000000..5dd07797 --- /dev/null +++ b/api/app/controllers/i18n_controller.py @@ -0,0 +1,833 @@ +""" +I18n Management API Controller + +This module provides management APIs for: +- Language management (list, get, add, update 
languages) +- Translation management (get, update, reload translations) +""" + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session +from typing import Callable, Optional + +from app.core.logging_config import get_api_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user, get_current_superuser +from app.i18n.dependencies import get_translator +from app.i18n.service import get_translation_service +from app.models.user_model import User +from app.schemas.i18n_schema import ( + LanguageInfo, + LanguageListResponse, + LanguageCreateRequest, + LanguageUpdateRequest, + TranslationResponse, + TranslationUpdateRequest, + MissingTranslationsResponse, + ReloadResponse +) +from app.schemas.response_schema import ApiResponse + +api_logger = get_api_logger() + +router = APIRouter( + prefix="/i18n", + tags=["I18n Management"], +) + + +# ============================================================================ +# Language Management APIs +# ============================================================================ + +@router.get("/languages", response_model=ApiResponse) +def get_languages( + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get list of all supported languages. 
+ + Returns: + List of language information including code, name, and status + """ + api_logger.info(f"Get languages request from user: {current_user.username}") + + from app.core.config import settings + translation_service = get_translation_service() + + # Get available locales from translation service + available_locales = translation_service.get_available_locales() + + # Build language info list + languages = [] + for locale in available_locales: + is_default = locale == settings.I18N_DEFAULT_LANGUAGE + is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES + + # Get native names + native_names = { + "zh": "中文(简体)", + "en": "English", + "ja": "日本語", + "ko": "한국어", + "fr": "Français", + "de": "Deutsch", + "es": "Español" + } + + language_info = LanguageInfo( + code=locale, + name=f"{locale.upper()}", + native_name=native_names.get(locale, locale), + is_enabled=is_enabled, + is_default=is_default + ) + languages.append(language_info) + + response = LanguageListResponse(languages=languages) + + api_logger.info(f"Returning {len(languages)} languages") + return success(data=response.dict(), msg=t("common.success.retrieved")) + + +@router.get("/languages/{locale}", response_model=ApiResponse) +def get_language( + locale: str, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get information about a specific language. 
+ + Args: + locale: Language code (e.g., 'zh', 'en') + + Returns: + Language information + """ + api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}") + + from app.core.config import settings + translation_service = get_translation_service() + + # Check if locale exists + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + api_logger.warning(f"Language not found: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + # Build language info + is_default = locale == settings.I18N_DEFAULT_LANGUAGE + is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES + + native_names = { + "zh": "中文(简体)", + "en": "English", + "ja": "日本語", + "ko": "한국어", + "fr": "Français", + "de": "Deutsch", + "es": "Español" + } + + language_info = LanguageInfo( + code=locale, + name=f"{locale.upper()}", + native_name=native_names.get(locale, locale), + is_enabled=is_enabled, + is_default=is_default + ) + + api_logger.info(f"Returning language info for: {locale}") + return success(data=language_info.dict(), msg=t("common.success.retrieved")) + + +@router.post("/languages", response_model=ApiResponse) +def add_language( + request: LanguageCreateRequest, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Add a new language (admin only). + + Note: This endpoint validates the request but actual language addition + requires creating translation files in the locales directory. 
+ + Args: + request: Language creation request + + Returns: + Success message + """ + api_logger.info( + f"Add language request: code={request.code}, admin={current_user.username}" + ) + + from app.core.config import settings + translation_service = get_translation_service() + + # Check if language already exists + available_locales = translation_service.get_available_locales() + if request.code in available_locales: + api_logger.warning(f"Language already exists: {request.code}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=t("i18n.language.already_exists", locale=request.code) + ) + + # Note: Actual language addition requires creating translation files + # This endpoint serves as a validation and documentation point + + api_logger.info( + f"Language addition validated: {request.code}. " + "Translation files need to be created manually." + ) + + return success( + msg=t( + "i18n.language.add_instructions", + locale=request.code, + dir=settings.I18N_CORE_LOCALES_DIR + ) + ) + + +@router.put("/languages/{locale}", response_model=ApiResponse) +def update_language( + locale: str, + request: LanguageUpdateRequest, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Update language configuration (admin only). + + Note: This endpoint validates the request but actual configuration + changes require updating environment variables or config files. 
+ + Args: + locale: Language code + request: Language update request + + Returns: + Success message + """ + api_logger.info( + f"Update language request: locale={locale}, admin={current_user.username}" + ) + + translation_service = get_translation_service() + + # Check if language exists + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + api_logger.warning(f"Language not found: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + # Note: Actual configuration changes require updating settings + # This endpoint serves as a validation and documentation point + + api_logger.info( + f"Language update validated: {locale}. " + "Configuration changes require environment variable updates." + ) + + return success(msg=t("i18n.language.update_instructions", locale=locale)) + + +# ============================================================================ +# Translation Management APIs +# ============================================================================ + +@router.get("/translations", response_model=ApiResponse) +def get_all_translations( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get all translations for all or specific locale. 
+ + Args: + locale: Optional locale filter + + Returns: + All translations organized by locale and namespace + """ + api_logger.info( + f"Get all translations request: locale={locale}, user={current_user.username}" + ) + + translation_service = get_translation_service() + + if locale: + # Get translations for specific locale + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + translations = { + locale: translation_service._cache.get(locale, {}) + } + else: + # Get all translations + translations = translation_service._cache + + response = TranslationResponse(translations=translations) + + api_logger.info(f"Returning translations for: {locale or 'all locales'}") + return success(data=response.dict(), msg=t("common.success.retrieved")) + + +@router.get("/translations/{locale}", response_model=ApiResponse) +def get_locale_translations( + locale: str, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get all translations for a specific locale. 
+ + Args: + locale: Language code + + Returns: + All translations for the locale organized by namespace + """ + api_logger.info( + f"Get locale translations request: locale={locale}, user={current_user.username}" + ) + + translation_service = get_translation_service() + + # Check if locale exists + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + api_logger.warning(f"Language not found: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + translations = translation_service._cache.get(locale, {}) + + api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}") + return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved")) + + +@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse) +def get_namespace_translations( + locale: str, + namespace: str, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get translations for a specific namespace in a locale. 
+ + Args: + locale: Language code + namespace: Translation namespace (e.g., 'common', 'auth') + + Returns: + Translations for the specified namespace + """ + api_logger.info( + f"Get namespace translations request: locale={locale}, " + f"namespace={namespace}, user={current_user.username}" + ) + + translation_service = get_translation_service() + + # Check if locale exists + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + api_logger.warning(f"Language not found: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + # Get namespace translations + locale_translations = translation_service._cache.get(locale, {}) + namespace_translations = locale_translations.get(namespace, {}) + + if not namespace_translations: + api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale) + ) + + api_logger.info( + f"Returning translations for namespace: {namespace} in locale: {locale}" + ) + return success( + data={ + "locale": locale, + "namespace": namespace, + "translations": namespace_translations + }, + msg=t("common.success.retrieved") + ) + + +@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse) +def update_translation( + locale: str, + key: str, + request: TranslationUpdateRequest, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Update a single translation (admin only). + + Note: This endpoint validates the request but actual translation updates + require modifying translation files in the locales directory. 
+ + Args: + locale: Language code + key: Translation key (format: "namespace.key.subkey") + request: Translation update request + + Returns: + Success message + """ + api_logger.info( + f"Update translation request: locale={locale}, key={key}, " + f"admin={current_user.username}" + ) + + translation_service = get_translation_service() + + # Check if locale exists + available_locales = translation_service.get_available_locales() + if locale not in available_locales: + api_logger.warning(f"Language not found: {locale}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=t("i18n.language.not_found", locale=locale) + ) + + # Validate key format + if "." not in key: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=t("i18n.translation.invalid_key_format", key=key) + ) + + # Note: Actual translation updates require modifying JSON files + # This endpoint serves as a validation and documentation point + + api_logger.info( + f"Translation update validated: {locale}/{key}. " + "Translation files need to be updated manually." + ) + + return success( + msg=t("i18n.translation.update_instructions", locale=locale, key=key) + ) + + +@router.get("/translations/missing", response_model=ApiResponse) +def get_missing_translations( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_user) +): + """ + Get list of missing translations. + + Compares translations across locales to find missing keys. 
+ + Args: + locale: Optional locale to check (defaults to checking all non-default locales) + + Returns: + List of missing translation keys + """ + api_logger.info( + f"Get missing translations request: locale={locale}, user={current_user.username}" + ) + + from app.core.config import settings + translation_service = get_translation_service() + + default_locale = settings.I18N_DEFAULT_LANGUAGE + available_locales = translation_service.get_available_locales() + + # Get default locale translations as reference + default_translations = translation_service._cache.get(default_locale, {}) + + # Collect all keys from default locale + def collect_keys(data, prefix=""): + keys = [] + for key, value in data.items(): + full_key = f"{prefix}.{key}" if prefix else key + if isinstance(value, dict): + keys.extend(collect_keys(value, full_key)) + else: + keys.append(full_key) + return keys + + default_keys = set() + for namespace, translations in default_translations.items(): + namespace_keys = collect_keys(translations, namespace) + default_keys.update(namespace_keys) + + # Find missing keys in target locale(s) + missing_by_locale = {} + + target_locales = [locale] if locale else [ + loc for loc in available_locales if loc != default_locale + ] + + for target_locale in target_locales: + if target_locale not in available_locales: + continue + + target_translations = translation_service._cache.get(target_locale, {}) + target_keys = set() + + for namespace, translations in target_translations.items(): + namespace_keys = collect_keys(translations, namespace) + target_keys.update(namespace_keys) + + missing_keys = default_keys - target_keys + if missing_keys: + missing_by_locale[target_locale] = sorted(list(missing_keys)) + + response = MissingTranslationsResponse(missing_translations=missing_by_locale) + + total_missing = sum(len(keys) for keys in missing_by_locale.values()) + api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales") + + 
return success(data=response.dict(), msg=t("common.success.retrieved")) + + +@router.post("/reload", response_model=ApiResponse) +def reload_translations( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Trigger hot reload of translation files (admin only). + + Args: + locale: Optional locale to reload (defaults to reloading all locales) + + Returns: + Reload status and statistics + """ + api_logger.info( + f"Reload translations request: locale={locale or 'all'}, " + f"admin={current_user.username}" + ) + + from app.core.config import settings + + if not settings.I18N_ENABLE_HOT_RELOAD: + api_logger.warning("Hot reload is disabled in configuration") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=t("i18n.reload.disabled") + ) + + translation_service = get_translation_service() + + try: + # Reload translations + translation_service.reload(locale) + + # Get statistics + available_locales = translation_service.get_available_locales() + reloaded_locales = [locale] if locale else available_locales + + response = ReloadResponse( + success=True, + reloaded_locales=reloaded_locales, + total_locales=len(available_locales) + ) + + api_logger.info( + f"Successfully reloaded translations for: {', '.join(reloaded_locales)}" + ) + + return success(data=response.dict(), msg=t("i18n.reload.success")) + + except Exception as e: + api_logger.error(f"Failed to reload translations: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=t("i18n.reload.failed", error=str(e)) + ) + + +# ============================================================================ +# Performance Monitoring APIs +# ============================================================================ + +@router.get("/metrics", response_model=ApiResponse) +def get_metrics( + t: Callable = Depends(get_translator), + current_user: User = 
Depends(get_current_superuser) +): + """ + Get i18n performance metrics (admin only). + + Returns: + Performance metrics including: + - Request counts + - Missing translations + - Timing statistics + - Locale usage + - Error counts + """ + api_logger.info(f"Get metrics request: admin={current_user.username}") + + translation_service = get_translation_service() + metrics = translation_service.get_metrics_summary() + + api_logger.info("Returning i18n metrics") + return success(data=metrics, msg=t("common.success.retrieved")) + + +@router.get("/metrics/cache", response_model=ApiResponse) +def get_cache_stats( + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Get cache statistics (admin only). + + Returns: + Cache statistics including: + - Hit/miss rates + - LRU cache performance + - Loaded locales + - Memory usage + """ + api_logger.info(f"Get cache stats request: admin={current_user.username}") + + translation_service = get_translation_service() + cache_stats = translation_service.get_cache_stats() + memory_usage = translation_service.get_memory_usage() + + data = { + "cache": cache_stats, + "memory": memory_usage + } + + api_logger.info("Returning cache statistics") + return success(data=data, msg=t("common.success.retrieved")) + + +@router.get("/metrics/prometheus") +def get_prometheus_metrics( + current_user: User = Depends(get_current_superuser) +): + """ + Get metrics in Prometheus format (admin only). 
+ + Returns: + Prometheus-formatted metrics as plain text + """ + api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}") + + from app.i18n.metrics import get_metrics + metrics = get_metrics() + prometheus_output = metrics.export_prometheus() + + from fastapi.responses import PlainTextResponse + return PlainTextResponse(content=prometheus_output) + + +@router.post("/metrics/reset", response_model=ApiResponse) +def reset_metrics( + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Reset all metrics (admin only). + + Returns: + Success message + """ + api_logger.info(f"Reset metrics request: admin={current_user.username}") + + from app.i18n.metrics import get_metrics + metrics = get_metrics() + metrics.reset() + + translation_service = get_translation_service() + translation_service.cache.reset_stats() + + api_logger.info("Metrics reset completed") + return success(msg=t("i18n.metrics.reset_success")) + + +# ============================================================================ +# Missing Translation Logging and Reporting APIs +# ============================================================================ + +@router.get("/logs/missing", response_model=ApiResponse) +def get_missing_translation_logs( + locale: Optional[str] = None, + limit: Optional[int] = 100, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Get missing translation logs (admin only). + + Returns logged missing translations with context information. 
+ + Args: + locale: Optional locale filter + limit: Maximum number of entries to return (default: 100) + + Returns: + Missing translation logs with context + """ + api_logger.info( + f"Get missing translation logs request: locale={locale}, " + f"limit={limit}, admin={current_user.username}" + ) + + translation_service = get_translation_service() + translation_logger = translation_service.translation_logger + + # Get missing translations + missing_translations = translation_logger.get_missing_translations(locale) + + # Get missing with context + missing_with_context = translation_logger.get_missing_with_context(locale, limit) + + # Get statistics + statistics = translation_logger.get_statistics() + + data = { + "missing_translations": missing_translations, + "recent_context": missing_with_context, + "statistics": statistics + } + + api_logger.info( + f"Returning {statistics['total_missing']} missing translations" + ) + return success(data=data, msg=t("common.success.retrieved")) + + +@router.get("/logs/missing/report", response_model=ApiResponse) +def generate_missing_translation_report( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Generate a comprehensive missing translation report (admin only). 
+ + Args: + locale: Optional locale filter + + Returns: + Comprehensive report with missing translations and statistics + """ + api_logger.info( + f"Generate missing translation report request: locale={locale}, " + f"admin={current_user.username}" + ) + + translation_service = get_translation_service() + translation_logger = translation_service.translation_logger + + # Generate report + report = translation_logger.generate_report(locale) + + api_logger.info( + f"Generated report with {report['total_missing']} missing translations" + ) + return success(data=report, msg=t("common.success.retrieved")) + + +@router.post("/logs/missing/export", response_model=ApiResponse) +def export_missing_translations( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Export missing translations to JSON file (admin only). + + Args: + locale: Optional locale filter + + Returns: + Export status and file path + """ + api_logger.info( + f"Export missing translations request: locale={locale}, " + f"admin={current_user.username}" + ) + + from datetime import datetime + translation_service = get_translation_service() + translation_logger = translation_service.translation_logger + + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + locale_suffix = f"_{locale}" if locale else "_all" + output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json" + + # Export to file + translation_logger.export_to_json(output_file) + + api_logger.info(f"Missing translations exported to: {output_file}") + return success( + data={"file_path": output_file}, + msg=t("i18n.logs.export_success", file=output_file) + ) + + +@router.delete("/logs/missing", response_model=ApiResponse) +def clear_missing_translation_logs( + locale: Optional[str] = None, + t: Callable = Depends(get_translator), + current_user: User = Depends(get_current_superuser) +): + """ + Clear missing 
translation logs (admin only). + + Args: + locale: Optional locale to clear (clears all if not specified) + + Returns: + Success message + """ + api_logger.info( + f"Clear missing translation logs request: locale={locale or 'all'}, " + f"admin={current_user.username}" + ) + + translation_service = get_translation_service() + translation_logger = translation_service.translation_logger + + # Clear logs + translation_logger.clear(locale) + + api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}") + return success(msg=t("i18n.logs.clear_success")) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 5a32372a..f827eaaf 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,3 +1,19 @@ +""" +Memory Reflection Controller + +This module provides REST API endpoints for managing memory reflection configurations +and operations. It handles reflection engine setup, configuration management, and +execution of self-reflection processes across memory systems. 
+ +Key Features: +- Reflection configuration management (save, retrieve, update) +- Workspace-wide reflection execution across multiple applications +- Individual configuration-based reflection runs +- Multi-language support for reflection outputs +- Integration with Neo4j memory storage and LLM models +- Comprehensive error handling and logging +""" + import asyncio import time import uuid @@ -28,9 +44,13 @@ from sqlalchemy.orm import Session from app.utils.config_utils import resolve_config_id +# Load environment variables for configuration load_dotenv() + +# Initialize API logger for request tracking and debugging api_logger = get_api_logger() +# Configure router with prefix and tags for API organization router = APIRouter( prefix="/memory", tags=["Memory"], @@ -43,7 +63,38 @@ async def save_reflection_config( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Save reflection configuration to data_comfig table""" + """ + Save reflection configuration to memory config table + + Persists reflection engine configuration settings to the data_config table, + including reflection parameters, model settings, and evaluation criteria. + Validates configuration parameters and ensures data consistency. 
+ + Args: + request: Memory reflection configuration data including: + - config_id: Configuration identifier to update + - reflection_enabled: Whether reflection is enabled + - reflection_period_in_hours: Reflection execution interval + - reflexion_range: Scope of reflection (partial/all) + - baseline: Reflection strategy (time/fact/hybrid) + - reflection_model_id: LLM model for reflection operations + - memory_verify: Enable memory verification checks + - quality_assessment: Enable quality assessment evaluation + current_user: Authenticated user saving the configuration + db: Database session for data operations + + Returns: + dict: Success response with saved reflection configuration data + + Raises: + HTTPException 400: If config_id is missing or parameters are invalid + HTTPException 500: If configuration save operation fails + + Database Operations: + - Updates memory_config table with reflection settings + - Commits transaction and refreshes entity + - Maintains configuration consistency + """ try: config_id = request.config_id config_id = resolve_config_id(config_id, db) @@ -54,6 +105,7 @@ async def save_reflection_config( ) api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") + # Update reflection configuration in database memory_config = MemoryConfigRepository.update_reflection_config( db, config_id=config_id, @@ -66,6 +118,7 @@ async def save_reflection_config( quality_assessment=request.quality_assessment ) + # Commit transaction and refresh entity db.commit() db.refresh(memory_config) @@ -102,13 +155,55 @@ async def start_workspace_reflection( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """启动工作空间中所有匹配应用的反思功能""" + """ + Start reflection functionality for all matching applications in workspace + + Initiates reflection processes across all applications within the user's current + workspace that have valid memory configurations. 
Processes each application's + configurations and associated end users, executing reflection operations + with proper error isolation and transaction management. + + This endpoint serves as a workspace-wide reflection orchestrator, ensuring + that reflection failures for individual users don't affect other operations. + + Args: + current_user: Authenticated user initiating workspace reflection + db: Database session for configuration queries + + Returns: + dict: Success response with reflection results for all processed applications: + - app_id: Application identifier + - config_id: Memory configuration identifier + - end_user_id: End user identifier + - reflection_result: Individual reflection operation result + + Processing Logic: + 1. Retrieve all applications in the current workspace + 2. Filter applications with valid memory configurations + 3. For each configuration, find matching releases + 4. Execute reflection for each end user with isolated transactions + 5. Aggregate results with error handling per user + + Error Handling: + - Individual user reflection failures are isolated + - Failed operations are logged and included in results + - Database transactions are isolated per user to prevent cascading failures + - Comprehensive error reporting for debugging + + Raises: + HTTPException 500: If workspace reflection initialization fails + + Performance Notes: + - Uses independent database sessions for each user operation + - Prevents transaction failures from affecting other users + - Comprehensive logging for operation tracking + """ workspace_id = current_user.current_workspace_id try: api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}") - # 使用独立的数据库会话来获取工作空间应用详情,避免事务失败 + # Use independent database session to get workspace app details, avoiding transaction failures from app.db import get_db_context with get_db_context() as query_db: service = WorkspaceAppService(query_db) @@ -116,8 +211,9 @@ async def 
start_workspace_reflection( reflection_results = [] + # Process each application in the workspace for data in result['apps_detailed_info']: - # 跳过没有配置的应用 + # Skip applications without configurations if not data['memory_configs']: api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过") continue @@ -126,22 +222,22 @@ async def start_workspace_reflection( memory_configs = data['memory_configs'] end_users = data['end_users'] - # 为每个配置和用户组合执行反思 + # Execute reflection for each configuration and user combination for config in memory_configs: config_id_str = str(config['config_id']) - # 找到匹配此配置的所有release + # Find all releases matching this configuration matching_releases = [r for r in releases if str(r['config']) == config_id_str] if not matching_releases: api_logger.debug(f"配置 {config_id_str} 没有匹配的release") continue - # 为每个用户执行反思 - 使用独立的数据库会话 + # Execute reflection for each user - using independent database sessions for user in end_users: api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}") - # 为每个用户创建独立的数据库会话,避免事务失败影响其他用户 + # Create independent database session for each user to avoid transaction failure impact with get_db_context() as user_db: try: reflection_service = MemoryReflectionService(user_db) @@ -184,14 +280,51 @@ async def start_reflection_configs( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """通过config_id查询memory_config表中的反思配置信息""" + """ + Query reflection configuration information by config_id + + Retrieves detailed reflection configuration settings from the memory_config + table for a specific configuration ID. Provides comprehensive reflection + parameters including model settings, evaluation criteria, and operational flags. 
+ + Args: + config_id: Configuration identifier (UUID or integer) to query + current_user: Authenticated user making the request + db: Database session for data operations + + Returns: + dict: Success response with detailed reflection configuration: + - config_id: Resolved configuration identifier + - reflection_enabled: Whether reflection is enabled for this config + - reflection_period_in_hours: Reflection execution interval + - reflexion_range: Scope of reflection operations (partial/all) + - baseline: Reflection strategy (time/fact/hybrid) + - reflection_model_id: LLM model identifier for reflection + - memory_verify: Memory verification flag + - quality_assessment: Quality assessment flag + + Database Operations: + - Queries memory_config table by resolved config_id + - Retrieves all reflection-related configuration fields + - Resolves configuration ID for consistent formatting + + Raises: + HTTPException 404: If configuration with specified ID is not found + HTTPException 500: If configuration query operation fails + + ID Resolution: + - Supports both UUID and integer config_id formats + - Automatically resolves to appropriate internal format + - Maintains consistency across different ID representations + """ config_id = resolve_config_id(config_id, db) try: config_id=resolve_config_id(config_id,db) api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) memory_config_id = resolve_config_id(result.config_id, db) - # 构建返回数据 + + # Build response data with comprehensive configuration details reflection_config = { "config_id": memory_config_id, "reflection_enabled": result.enable_self_reflexion, @@ -204,10 +337,12 @@ async def start_reflection_configs( } api_logger.info(f"成功查询反思配置,config_id: {config_id}") return success(data=reflection_config, msg="反思配置查询成功") - + api_logger.info(f"Successfully queried reflection config, config_id: {config_id}") + return 
success(data=reflection_config, msg="Reflection configuration query successful") + except HTTPException: - # 重新抛出HTTP异常 + # Re-raise HTTP exceptions without modification raise except Exception as e: api_logger.error(f"查询反思配置失败: {str(e)}") @@ -223,13 +358,66 @@ async def reflection_run( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Activate the reflection function for all matching applications in the workspace""" - # 使用集中化的语言校验 + """ + Execute reflection engine with specified configuration + + Runs the reflection engine using configuration parameters from the database. + Validates model availability, sets up the reflection engine with proper + configuration, and executes the reflection process with multi-language support. + + This endpoint provides a test run capability for reflection configurations, + allowing users to validate their reflection settings and see results before + deploying to production environments. + + Args: + config_id: Configuration identifier (UUID or integer) for reflection settings + language_type: Language preference header for output localization (optional) + current_user: Authenticated user executing the reflection + db: Database session for configuration queries + + Returns: + dict: Success response with reflection execution results including: + - baseline: Reflection strategy used + - source_data: Input data processed + - memory_verifies: Memory verification results (if enabled) + - quality_assessments: Quality assessment results (if enabled) + - reflexion_data: Generated reflection insights and solutions + + Configuration Validation: + - Verifies configuration exists in database + - Validates LLM model availability + - Falls back to default model if specified model is unavailable + - Ensures all required parameters are properly set + + Reflection Engine Setup: + - Creates ReflectionConfig with database parameters + - Initializes Neo4j connector for memory access + - Sets up 
ReflectionEngine with validated model + - Configures language preferences for output + + Error Handling: + - Model validation with fallback to default + - Configuration validation and error reporting + - Comprehensive logging for debugging + - Graceful handling of missing configurations + + Raises: + HTTPException 404: If configuration is not found + HTTPException 500: If reflection execution fails + + Performance Notes: + - Direct database query for configuration retrieval + - Model validation to prevent runtime failures + - Efficient reflection engine initialization + - Language-aware output processing + """ + # Use centralized language validation for consistent localization language = get_language_from_header(language_type) api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") config_id = resolve_config_id(config_id, db) - # 使用MemoryConfigRepository查询反思配置 + + # Query reflection configuration using MemoryConfigRepository result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( @@ -239,7 +427,7 @@ async def reflection_run( api_logger.info(f"成功查询反思配置,config_id: {config_id}") - # 验证模型ID是否存在 + # Validate model ID existence model_id = result.reflection_model_id if model_id: try: @@ -250,6 +438,7 @@ async def reflection_run( # 可以设置为None,让反思引擎使用默认模型 model_id = None + # Create reflection configuration with database parameters config = ReflectionConfig( enabled=result.enable_self_reflexion, iteration_period=result.iteration_period, @@ -262,11 +451,13 @@ async def reflection_run( model_id=model_id, language_type=language_type ) + + # Initialize Neo4j connector and reflection engine connector = Neo4jConnector() engine = ReflectionEngine( config=config, neo4j_connector=connector, - llm_client=model_id # 传入验证后的 model_id + llm_client=model_id # Pass validated model_id ) result=await (engine.reflection_run()) diff --git a/api/app/controllers/memory_short_term_controller.py 
b/api/app/controllers/memory_short_term_controller.py index 0acac6ce..b69406a8 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -1,3 +1,18 @@ +""" +Memory Short Term Controller + +This module provides REST API endpoints for managing short-term and long-term memory +data retrieval and analysis. It handles memory system statistics, data aggregation, +and provides comprehensive memory insights for end users. + +Key Features: +- Short-term memory data retrieval and statistics +- Long-term memory data aggregation +- Entity count integration +- Multi-language response support +- Memory system analytics and reporting +""" + from typing import Optional from dotenv import load_dotenv @@ -13,9 +28,13 @@ from app.models.user_model import User from app.services.memory_short_service import LongService, ShortService from app.services.memory_storage_service import search_entity +# Load environment variables for configuration load_dotenv() + +# Initialize API logger for request tracking and debugging api_logger = get_api_logger() +# Configure router with prefix and tags for API organization router = APIRouter( prefix="/memory/short", tags=["Memory"], @@ -27,24 +46,73 @@ async def short_term_configs( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): - # 使用集中化的语言校验 + """ + Retrieve comprehensive short-term and long-term memory statistics + + Provides a comprehensive overview of memory system data for a specific end user, + including short-term memory entries, long-term memory aggregations, entity counts, + and retrieval statistics. Supports multi-language responses based on request headers. + + This endpoint serves as a central dashboard for memory system analytics, combining + data from multiple memory subsystems to provide a holistic view of user memory state. 
+ + Args: + end_user_id: Unique identifier for the end user whose memory data to retrieve + language_type: Language preference header for response localization (optional) + current_user: Authenticated user making the request (injected by dependency) + db: Database session for data operations (injected by dependency) + + Returns: + dict: Success response containing comprehensive memory statistics: + - short_term: List of short-term memory entries with detailed data + - long_term: List of long-term memory aggregations and summaries + - entity: Count of entities associated with the end user + - retrieval_number: Total count of short-term memory retrievals + - long_term_number: Total count of long-term memory entries + + Response Structure: + { + "code": 200, + "msg": "Short-term memory system data retrieved successfully", + "data": { + "short_term": [...], # Short-term memory entries + "long_term": [...], # Long-term memory data + "entity": 42, # Entity count + "retrieval_number": 156, # Short-term retrieval count + "long_term_number": 23 # Long-term memory count + } + } + + Raises: + HTTPException: If end_user_id is invalid or data retrieval fails + + Performance Notes: + - Combines multiple service calls for comprehensive data + - Entity search is performed asynchronously for better performance + - Response time depends on memory data volume for the specified user + """ + # Use centralized language validation for consistent localization language = get_language_from_header(language_type) - # 获取短期记忆数据 - short_term=ShortService(end_user_id, db) - short_result=short_term.get_short_databasets() - short_count=short_term.get_short_count() + # Retrieve short-term memory data and statistics + short_term = ShortService(end_user_id, db) + short_result = short_term.get_short_databasets() # Get short-term memory entries + short_count = short_term.get_short_count() # Get short-term retrieval count - long_term=LongService(end_user_id, db) - 
long_result=long_term.get_long_databasets() + # Retrieve long-term memory data and aggregations + long_term = LongService(end_user_id, db) + long_result = long_term.get_long_databasets() # Get long-term memory entries + # Get entity count for the specified end user entity_result = await search_entity(end_user_id) + + # Compile comprehensive memory statistics response result = { - 'short_term': short_result, - 'long_term': long_result, - 'entity': entity_result.get('num', 0), - "retrieval_number":short_count, - "long_term_number":len(long_result) + 'short_term': short_result, # Short-term memory entries + 'long_term': long_result, # Long-term memory data + 'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found) + "retrieval_number": short_count, # Short-term retrieval statistics + "long_term_number": len(long_result) # Long-term memory entry count } return success(data=result, msg="短期记忆系统数据获取成功") \ No newline at end of file diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index 2806da1a..16213690 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session import uuid +from typing import Callable from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -19,6 +20,7 @@ from app.services import user_service from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.core.security import verify_password +from app.i18n.dependencies import get_translator # 获取API专用日志器 api_logger = get_api_logger() @@ -33,7 +35,8 @@ router = APIRouter( def create_superuser( user: user_schema.UserCreate, db: Session = Depends(get_db), - current_superuser: User = Depends(get_current_superuser) + current_superuser: User = Depends(get_current_superuser), + t: Callable = Depends(get_translator) ): """创建超级管理员(仅超级管理员可访问)""" 
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}") @@ -42,7 +45,7 @@ def create_superuser( api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})") result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="超级管理员创建成功") + return success(data=result_schema, msg=t("users.create.superuser_success")) @router.delete("/{user_id}", response_model=ApiResponse) @@ -50,6 +53,7 @@ def delete_user( user_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """停用用户(软删除)""" api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}") @@ -57,13 +61,14 @@ def delete_user( db=db, user_id_to_deactivate=user_id, current_user=current_user ) api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})") - return success(msg="用户停用成功") + return success(msg=t("users.delete.deactivate_success")) @router.post("/{user_id}/activate", response_model=ApiResponse) def activate_user( user_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """激活用户""" api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}") @@ -74,13 +79,14 @@ def activate_user( api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})") result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="用户激活成功") + return success(data=result_schema, msg=t("users.activate.success")) @router.get("", response_model=ApiResponse) def get_current_user_info( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """获取当前用户信息""" api_logger.info(f"当前用户信息请求: {current_user.username}") @@ -105,7 +111,7 @@ def get_current_user_info( break api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") - return 
success(data=result_schema, msg="用户信息获取成功") + return success(data=result_schema, msg=t("users.info.get_success")) @router.get("/superusers", response_model=ApiResponse) @@ -113,6 +119,7 @@ def get_tenant_superusers( include_inactive: bool = False, db: Session = Depends(get_db), current_user: User = Depends(get_current_superuser), + t: Callable = Depends(get_translator) ): """获取当前租户下的超管账号列表(仅超级管理员可访问)""" api_logger.info(f"获取租户超管列表请求: {current_user.username}") @@ -125,7 +132,7 @@ def get_tenant_superusers( api_logger.info(f"租户超管列表获取成功: count={len(superusers)}") superusers_schema = [user_schema.User.model_validate(u) for u in superusers] - return success(data=superusers_schema, msg="租户超管列表获取成功") + return success(data=superusers_schema, msg=t("users.list.superusers_success")) @@ -134,6 +141,7 @@ def get_user_info_by_id( user_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """根据用户ID获取用户信息""" api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}") @@ -144,7 +152,7 @@ def get_user_info_by_id( api_logger.info(f"用户信息获取成功: {result.username}") result_schema = user_schema.User.model_validate(result) - return success(data=result_schema, msg="用户信息获取成功") + return success(data=result_schema, msg=t("users.info.get_success")) @router.put("/change-password", response_model=ApiResponse) @@ -152,6 +160,7 @@ async def change_password( request: ChangePasswordRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """修改当前用户密码""" api_logger.info(f"用户密码修改请求: {current_user.username}") @@ -164,7 +173,7 @@ async def change_password( current_user=current_user ) api_logger.info(f"用户密码修改成功: {current_user.username}") - return success(msg="密码修改成功") + return success(msg=t("auth.password.change_success")) @router.put("/admin/change-password", response_model=ApiResponse) @@ -172,6 +181,7 @@ async def 
admin_change_password( request: AdminChangePasswordRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_superuser), + t: Callable = Depends(get_translator) ): """超级管理员修改指定用户的密码""" api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}") @@ -186,16 +196,17 @@ async def admin_change_password( # 根据是否生成了随机密码来构造响应 if request.new_password: api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}") - return success(msg="密码修改成功") + return success(msg=t("auth.password.change_success")) else: api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成") - return success(data=generated_password, msg="密码重置成功") + return success(data=generated_password, msg=t("auth.password.reset_success")) @router.post("/verify_pwd", response_model=ApiResponse) def verify_pwd( request: VerifyPasswordRequest, current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """验证当前用户密码""" api_logger.info(f"用户验证密码请求: {current_user.username}") @@ -203,8 +214,8 @@ def verify_pwd( is_valid = verify_password(request.password, current_user.hashed_password) api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}") if not is_valid: - raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED) - return success(data={"valid": is_valid}, msg="验证完成") + raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED) + return success(data={"valid": is_valid}, msg=t("common.success.retrieved")) @router.post("/send-email-code", response_model=ApiResponse) @@ -212,6 +223,7 @@ async def send_email_code( request: SendEmailCodeRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """发送邮箱验证码""" api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}") @@ -219,7 +231,7 @@ async def send_email_code( await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id) 
api_logger.info(f"邮箱验证码已发送: {current_user.username}") - return success(msg="验证码已发送到您的邮箱,请查收") + return success(msg=t("users.email.code_sent")) @router.put("/change-email", response_model=ApiResponse) @@ -227,6 +239,7 @@ async def change_email( request: VerifyEmailCodeRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) ): """验证验证码并修改邮箱""" api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}") @@ -239,4 +252,51 @@ async def change_email( ) api_logger.info(f"用户邮箱修改成功: {current_user.username}") - return success(msg="邮箱修改成功") + return success(msg=t("users.email.change_success")) + + + +@router.get("/me/language", response_model=ApiResponse) +def get_current_user_language( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) +): + """获取当前用户的语言偏好""" + api_logger.info(f"获取用户语言偏好: {current_user.username}") + + language = user_service.get_user_language_preference( + db=db, + user_id=current_user.id, + current_user=current_user + ) + + api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}") + return success( + data=user_schema.LanguagePreferenceResponse(language=language), + msg=t("users.language.get_success") + ) + + +@router.put("/me/language", response_model=ApiResponse) +def update_current_user_language( + request: user_schema.LanguagePreferenceRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user), + t: Callable = Depends(get_translator) +): + """设置当前用户的语言偏好""" + api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}") + + updated_user = user_service.update_user_language_preference( + db=db, + user_id=current_user.id, + language=request.language, + current_user=current_user + ) + + api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}") + return success( + 
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language), + msg=t("users.language.update_success") + ) diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index 9bcd8571..6f4a4fa8 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -14,6 +14,12 @@ from app.dependencies import ( get_current_user, workspace_access_guard, ) +from app.i18n.dependencies import get_current_language, get_translator +from app.i18n.serializers import ( + WorkspaceSerializer, + WorkspaceMemberSerializer, + WorkspaceInviteSerializer +) from app.models.tenant_model import Tenants from app.models.user_model import User from app.models.workspace_model import InviteStatus @@ -65,7 +71,9 @@ def get_workspaces( include_current: bool = Query(True, description="是否包含当前工作空间"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), - current_tenant: Tenants = Depends(get_current_tenant) + current_tenant: Tenants = Depends(get_current_tenant), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """获取当前租户下用户参与的所有工作空间 @@ -88,8 +96,13 @@ def get_workspaces( ) api_logger.info(f"成功获取 {len(workspaces)} 个工作空间") - workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces] - return success(data=workspaces_schema, msg="工作空间列表获取成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceSerializer() + workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces] + workspaces_i18n = serializer.serialize_list(workspaces_data, language) + + return success(data=workspaces_i18n, msg=t("workspace.list_retrieved")) @router.post("", response_model=ApiResponse) @@ -98,6 +111,8 @@ def create_workspace( language_type: str = Header(default="zh", alias="X-Language-Type"), db: Session = Depends(get_db), current_user: User = Depends(get_current_superuser), + language: str = 
Depends(get_current_language), + t: callable = Depends(get_translator) ): """创建新的工作空间""" from app.core.language_utils import get_language_from_header @@ -118,8 +133,13 @@ def create_workspace( f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, " f"创建者: {current_user.username}, language={language}" ) - result_schema = WorkspaceResponse.model_validate(result) - return success(data=result_schema, msg="工作空间创建成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceSerializer() + result_data = WorkspaceResponse.model_validate(result).model_dump() + result_i18n = serializer.serialize(result_data, language) + + return success(data=result_i18n, msg=t("workspace.created")) @router.put("", response_model=ApiResponse) @cur_workspace_access_guard() @@ -127,6 +147,8 @@ def update_workspace( workspace: WorkspaceUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """更新工作空间""" workspace_id = current_user.current_workspace_id @@ -139,14 +161,21 @@ def update_workspace( user=current_user, ) api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}") - result_schema = WorkspaceResponse.model_validate(result) - return success(data=result_schema, msg="工作空间更新成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceSerializer() + result_data = WorkspaceResponse.model_validate(result).model_dump() + result_i18n = serializer.serialize(result_data, language) + + return success(data=result_i18n, msg=t("workspace.updated")) @router.get("/members", response_model=ApiResponse) @cur_workspace_access_guard() def get_cur_workspace_members( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """获取工作空间成员列表(关系序列化)""" api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表") @@ -157,8 +186,14 @@ def 
get_cur_workspace_members( user=current_user, ) api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}") + + # 转换为表格项并使用序列化器添加国际化字段 table_items = _convert_members_to_table_items(members) - return success(data=table_items, msg="工作空间成员列表获取成功") + serializer = WorkspaceMemberSerializer() + members_data = [item.model_dump() for item in table_items] + members_i18n = serializer.serialize_list(members_data, language) + + return success(data=members_i18n, msg=t("workspace.members.list_retrieved")) @router.put("/members", response_model=ApiResponse) @@ -168,6 +203,7 @@ def update_workspace_members( updates: List[WorkspaceMemberUpdate], db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: callable = Depends(get_translator) ): workspace_id = current_user.current_workspace_id api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色") @@ -178,7 +214,7 @@ def update_workspace_members( user=current_user, ) api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}") - return success(msg="成员角色更新成功") + return success(msg=t("workspace.members.role_updated")) @router.delete("/members/{member_id}", response_model=ApiResponse) @@ -187,6 +223,7 @@ def delete_workspace_member( member_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: callable = Depends(get_translator) ): workspace_id = current_user.current_workspace_id api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") @@ -198,7 +235,7 @@ def delete_workspace_member( user=current_user, ) api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}") - return success(msg="成员删除成功") + return success(msg=t("workspace.members.deleted")) # 创建空间协作邀请 @@ -208,6 +245,8 @@ def create_workspace_invite( invite_data: WorkspaceInviteCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), 
+ t: callable = Depends(get_translator) ): """创建工作空间邀请""" workspace_id = current_user.current_workspace_id @@ -220,7 +259,12 @@ def create_workspace_invite( user=current_user ) api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}") - return success(data=result, msg="邀请创建成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceInviteSerializer() + result_i18n = serializer.serialize(result, language) + + return success(data=result_i18n, msg=t("workspace.invites.created")) @router.get("/invites", response_model=ApiResponse) @@ -232,6 +276,8 @@ def get_workspace_invites( offset: int = Query(0, ge=0), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """获取工作空间邀请列表""" workspace_id = current_user.current_workspace_id @@ -246,18 +292,30 @@ def get_workspace_invites( offset=offset ) api_logger.info(f"成功获取 {len(invites)} 个邀请记录") - return success(data=invites, msg="邀请列表获取成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceInviteSerializer() + invites_i18n = serializer.serialize_list(invites, language) + + return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved")) @public_router.get("/invites/validate/{token}", response_model=ApiResponse) def get_workspace_invite_info( token: str, db: Session = Depends(get_db), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """获取工作空间邀请用户信息(无需认证)""" result = workspace_service.validate_invite_token(db=db, token=token) api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}") - return success(data=result, msg="邀请验证成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceInviteSerializer() + result_i18n = serializer.serialize(result, language) + + return success(data=result_i18n, msg=t("workspace.invites.validated")) @router.delete("/invites/{invite_id}", response_model=ApiResponse) @@ -267,6 +325,8 @@ def revoke_workspace_invite( invite_id: uuid.UUID, db: Session = 
Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), + t: callable = Depends(get_translator) ): """撤销工作空间邀请""" workspace_id = current_user.current_workspace_id @@ -279,7 +339,12 @@ def revoke_workspace_invite( user=current_user ) api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}") - return success(data=result, msg="邀请撤销成功") + + # 使用序列化器添加国际化字段 + serializer = WorkspaceInviteSerializer() + result_i18n = serializer.serialize(result, language) + + return success(data=result_i18n, msg=t("workspace.invites.revoked")) # ==================== 公开邀请接口(无需认证) ==================== @@ -302,6 +367,7 @@ def switch_workspace( workspace_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: callable = Depends(get_translator) ): """切换工作空间""" api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}") @@ -312,7 +378,7 @@ def switch_workspace( user=current_user, ) api_logger.info(f"成功切换工作空间为 {workspace_id}") - return success(msg="工作空间切换成功") + return success(msg=t("workspace.switched")) @router.get("/storage", response_model=ApiResponse) @@ -320,6 +386,7 @@ def switch_workspace( def get_workspace_storage_type( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: callable = Depends(get_translator) ): """获取当前工作空间的存储类型""" workspace_id = current_user.current_workspace_id @@ -331,7 +398,7 @@ def get_workspace_storage_type( user=current_user ) api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}") - return success(data={"storage_type": storage_type}, msg="存储类型获取成功") + return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved")) @router.get("/workspace_models", response_model=ApiResponse) @@ -339,6 +406,8 @@ def get_workspace_storage_type( def workspace_models_configs( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + language: str = Depends(get_current_language), + t: callable 
= Depends(get_translator) ): """获取当前工作空间的模型配置(llm, embedding, rerank)""" workspace_id = current_user.current_workspace_id @@ -354,14 +423,14 @@ def workspace_models_configs( api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="工作空间不存在或无权访问" + detail=t("workspace.not_found") ) api_logger.info( f"成功获取工作空间 {workspace_id} 的模型配置: " f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" ) - return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功") + return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved")) @router.put("/workspace_models", response_model=ApiResponse) @@ -370,6 +439,7 @@ def update_workspace_models_configs( models_update: WorkspaceModelsUpdate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), + t: callable = Depends(get_translator) ): """更新当前工作空间的模型配置(llm, embedding, rerank)""" workspace_id = current_user.current_workspace_id @@ -386,5 +456,5 @@ def update_workspace_models_configs( f"成功更新工作空间 {workspace_id} 的模型配置: " f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}" ) - return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功") + return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated")) diff --git a/api/app/core/config.py b/api/app/core/config.py index 25713967..cdaa13cc 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -162,6 +162,44 @@ class Settings: # This controls the language used for memory summary titles and other generated content DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh") + # ======================================================================== + # Internationalization (i18n) Configuration + # 
======================================================================== + # Default language for API responses + I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh") + + # Supported languages (comma-separated) + I18N_SUPPORTED_LANGUAGES: list[str] = [ + lang.strip() + for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",") + if lang.strip() + ] + + # Core locales directory (community edition) + # Use absolute path to work from any working directory + I18N_CORE_LOCALES_DIR: str = os.getenv( + "I18N_CORE_LOCALES_DIR", + os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales") + ) + + # Premium locales directory (enterprise edition, optional) + I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None) + + # Enable translation cache + I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true" + + # LRU cache size for hot translations + I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000")) + + # Enable hot reload of translation files + I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true" + + # Fallback language when translation is missing + I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh") + + # Log missing translations + I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true" + # Logging settings LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py index 6595a2ce..ca08db76 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/data_nodes.py @@ -2,15 +2,37 @@ from app.core.memory.agent.utils.llm_tools import ReadState, 
WriteState def content_input_node(state: ReadState) -> ReadState: - """开始节点 - 提取内容并保持状态信息""" + """ + Start node - Extract content and maintain state information + + Extracts the content from the first message in the state and returns it + as the data field while preserving all other state information. + + Args: + state: ReadState containing messages and other state data + + Returns: + ReadState: Updated state with extracted content in data field + """ content = state['messages'][0].content if state.get('messages') else '' - # 返回内容并保持所有状态信息 + # Return content and maintain all state information return {"data": content} + def content_input_write(state: WriteState) -> WriteState: - """开始节点 - 提取内容并保持状态信息""" + """ + Start node - Extract content and maintain state information for write operations + + Extracts the content from the first message in the state for write operations. + + Args: + state: WriteState containing messages and other state data + + Returns: + WriteState: Updated state with extracted content in data field + """ content = state['messages'][0].content if state.get('messages') else '' - # 返回内容并保持所有状态信息 - return {"data": content} \ No newline at end of file + # Return content and maintain all state information + return {"data": content} diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 784e5802..3030669c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -19,19 +19,39 @@ logger = get_agent_logger(__name__) class ProblemNodeService(LLMServiceMixin): - """问题处理节点服务类""" + """ + Problem processing node service class + + Handles problem decomposition and extension operations using LLM services. + Inherits from LLMServiceMixin to provide structured LLM calling capabilities. 
+ + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance problem_service = ProblemNodeService() async def Split_The_Problem(state: ReadState) -> ReadState: - """问题分解节点""" + """ + Problem decomposition node + + Breaks down complex user queries into smaller, more manageable sub-problems. + Uses LLM to analyze the input and generate structured problem decomposition + with question types and reasoning. + + Args: + state: ReadState containing user input and configuration + + Returns: + ReadState: Updated state with problem decomposition results + """ # 从状态中获取数据 content = state.get('data', '') end_user_id = state.get('end_user_id', '') @@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") - # 验证结构化响应 + # Validate structured response if not structured or not hasattr(structured, 'root'): logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") split_result = json.dumps([], ensure_ascii=False) @@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: exc_info=True ) - # 提供更详细的错误信息 + # Provide more detailed error information error_details = { "error_type": type(e).__name__, "error_message": str(e), @@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: logger.error(f"Split_The_Problem error details: {error_details}") - # 创建默认的空结果 + # Create default empty result result = { "context": json.dumps([], ensure_ascii=False), "original": content, @@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState: } } - # 返回更新后的状态,包含spit_context字段 + # Return updated state including spit_context field return {"spit_data": result} async def Problem_Extension(state: ReadState) -> ReadState: - """问题扩展节点""" - # 获取原始数据和分解结果 + """ + Problem extension node + + Extends the 
decomposed problems from Split_The_Problem node by generating + additional related questions and organizing them by original question. + Uses LLM to create comprehensive question extensions for better memory retrieval. + + Args: + state: ReadState containing decomposed problems and configuration + + Returns: + ReadState: Updated state with extended problem results + """ + # Get original data and decomposition results start = time.time() content = state.get('data', '') data = state.get('spit_data', '')['context'] @@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") - # 验证结构化响应 + # Validate structured response if not response_content or not hasattr(response_content, 'root'): logger.warning("Problem_Extension: 结构化响应为空或格式不正确") aggregated_dict = {} @@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: exc_info=True ) - # 提供更详细的错误信息 + # Provide more detailed error information error_details = { "error_type": type(e).__name__, "error_message": str(e), diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 06539ad1..f2cd0d3d 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -29,6 +29,18 @@ logger = get_agent_logger(__name__) async def rag_config(state): + """ + Configure RAG (Retrieval-Augmented Generation) settings + + Creates configuration for knowledge base retrieval including similarity thresholds, + weights, and reranker settings. 
+ + Args: + state: Current state containing user_rag_memory_id + + Returns: + dict: RAG configuration dictionary + """ user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { "knowledge_bases": [ @@ -48,6 +60,19 @@ async def rag_config(state): async def rag_knowledge(state, question): + """ + Retrieve knowledge using RAG approach + + Performs knowledge retrieval from configured knowledge bases using the + provided question and returns formatted results. + + Args: + state: Current state containing configuration + question: Question to search for + + Returns: + tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results) + """ kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') user_rag_memory_id = state.get("user_rag_memory_id", '') @@ -68,12 +93,24 @@ async def rag_knowledge(state, question): async def llm_infomation(state: ReadState) -> ReadState: + """ + Get LLM configuration information from state + + Retrieves model configuration details including model ID and tenant ID + from the memory configuration in the current state. 
+ + Args: + state: ReadState containing memory configuration + + Returns: + ReadState: Model configuration as Pydantic model + """ memory_config = state.get('memory_config', None) model_id = memory_config.llm_model_id tenant_id = memory_config.tenant_id - # 使用现有的 memory_config 而不是重新查询数据库 - # 或者使用线程安全的数据库访问 + # Use existing memory_config instead of re-querying database + # or use thread-safe database access with get_db_context() as db: result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id) result_pydantic = model_schema.ModelConfig.model_validate(result_orm) @@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState: async def clean_databases(data) -> str: """ - 简化的数据库搜索结果清理函数 + Simplified database search result cleaning function + + Processes and cleans search results from various sources including + reranked results and time-based search results. Extracts text content + from structured data and returns as formatted string. Args: - data: 搜索结果数据 + data: Search result data (can be string, dict, or other types) Returns: - 清理后的内容字符串 + str: Cleaned content string """ try: - # 解析JSON字符串 + # Parse JSON string if isinstance(data, str): try: data = json.loads(data) @@ -101,24 +142,24 @@ async def clean_databases(data) -> str: if not isinstance(data, dict): return str(data) - # 获取结果数据 + # Get result data # with open("搜索结果.json","w",encoding='utf-8') as f: # f.write(json.dumps(data, indent=4, ensure_ascii=False)) results = data.get('results', data) if not isinstance(results, dict): return str(results) - # 收集所有内容 + # Collect all content content_list = [] - # 处理重排序结果 + # Process reranked results reranked = results.get('reranked_results', {}) if reranked: for category in ['summaries', 'statements', 'chunks', 'entities']: items = reranked.get(category, []) if isinstance(items, list): content_list.extend(items) - # 处理时间搜索结果 + # Process time search results time_search = results.get('time_search', {}) if time_search: if 
isinstance(time_search, dict): @@ -128,7 +169,7 @@ async def clean_databases(data) -> str: elif isinstance(time_search, list): content_list.extend(time_search) - # 提取文本内容 + # Extract text content text_parts = [] for item in content_list: if isinstance(item, dict): @@ -146,10 +187,19 @@ async def clean_databases(data) -> str: async def retrieve_nodes(state: ReadState) -> ReadState: - ''' - - 模型信息 - ''' + """ + Retrieve information using simplified search approach + + Processes extended problems from previous nodes and performs retrieval + using either RAG or hybrid search based on storage type. Handles concurrent + processing of multiple questions and deduplicates results. + + Args: + state: ReadState containing problem extensions and configuration + + Returns: + ReadState: Updated state with retrieval results and intermediate outputs + """ problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') @@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: problem_list.append(data) logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - # 创建异步任务处理单个问题 + # Create async task to process individual questions async def process_question_nodes(idx, question): try: # Prepare search parameters based on storage type @@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: } } - # 并发处理所有问题 + # Process all questions concurrently tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)] databases_anser = await asyncio.gather(*tasks) databases_data = { @@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState: - # 从state中获取end_user_id + """ + Advanced retrieve function using LangChain agents and tools + + Uses LangChain agents with specialized retrieval tools (time-based and hybrid) + to perform sophisticated information retrieval. 
Supports both RAG and traditional + memory storage approaches with concurrent processing and result deduplication. + + Args: + state: ReadState containing problem extensions and configuration + + Returns: + ReadState: Updated state with retrieval results and intermediate outputs + """ + # Get end_user_id from state import time start = time.time() problem_extension = state.get('problem_extension', '')['context'] @@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState: system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) - # 创建异步任务处理单个问题 + # Create async task to process individual questions import asyncio - # 在模块级别定义信号量,限制最大并发数 - SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作 + # Define semaphore at module level to limit maximum concurrency + SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations async def process_question(idx, question): - async with SEMAPHORE: # 限制并发 + async with SEMAPHORE: # Limit concurrency try: if storage_type == "rag" and user_rag_memory_id: retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) else: cleaned_query = question - # 使用 asyncio 在线程池中运行同步的 agent.invoke + # Use asyncio to run synchronous agent.invoke in thread pool import asyncio response = await asyncio.get_event_loop().run_in_executor( None, @@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState: } } - # 并发处理所有问题 + # Process all questions concurrently import asyncio tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)] databases_anser = await asyncio.gather(*tasks) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 87606bf8..030acc9a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -23,18 +23,39 @@ logger = 
get_agent_logger(__name__) class SummaryNodeService(LLMServiceMixin): - """总结节点服务类""" + """ + Summary node service class + + Handles summary generation operations using LLM services. Inherits from + LLMServiceMixin to provide structured LLM calling capabilities for + generating summaries from retrieved information. + + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance summary_service = SummaryNodeService() async def rag_config(state): + """ + Configure RAG (Retrieval-Augmented Generation) settings for summary operations + + Creates configuration for knowledge base retrieval including similarity thresholds, + weights, and reranker settings specifically for summary generation. + + Args: + state: Current state containing user_rag_memory_id + + Returns: + dict: RAG configuration dictionary with knowledge base settings + """ user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { "knowledge_bases": [ @@ -54,6 +75,23 @@ async def rag_config(state): async def rag_knowledge(state, question): + """ + Retrieve knowledge using RAG approach for summary generation + + Performs knowledge retrieval from configured knowledge bases using the + provided question and returns formatted results for summary processing. 
+ + Args: + state: Current state containing configuration + question: Question to search for in knowledge base + + Returns: + tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results) + - retrieval_knowledge: List of retrieved knowledge chunks + - clean_content: Formatted content string + - cleaned_query: Processed query string + - raw_results: Raw retrieval results + """ kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') user_rag_memory_id = state.get("user_rag_memory_id", '') @@ -74,6 +112,18 @@ async def rag_knowledge(state, question): async def summary_history(state: ReadState) -> ReadState: + """ + Retrieve conversation history for summary context + + Gets the conversation history for the current user to provide context + for summary generation operations. + + Args: + state: ReadState containing end_user_id + + Returns: + ReadState: Conversation history data + """ end_user_id = state.get("end_user_id", '') history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history @@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState: async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, search_mode) -> str: """ - 增强的summary_llm函数,包含更好的错误处理和数据验证 + Enhanced summary_llm function with better error handling and data validation + + Generates summaries using LLM with structured output. Includes fallback mechanisms + for handling LLM failures and provides robust error recovery. 
+ + Args: + state: ReadState containing current context + history: Conversation history for context + retrieve_info: Retrieved information to summarize + template_name: Jinja2 template name for prompt generation + operation_name: Type of operation (summary, input_summary, retrieve_summary) + response_model: Pydantic model for structured output + search_mode: Search mode flag ("0" for simple, "1" for complex) + + Returns: + str: Generated summary text or fallback message """ data = state.get("data", '') - # 构建系统提示词 + # Build system prompt if str(search_mode) == "0": system_prompt = await summary_service.template_service.render_template( template_name=template_name, @@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o retrieve_info=retrieve_info ) try: - # 使用优化的LLM服务进行结构化输出 + # Use optimized LLM service for structured output with get_db_context() as db_session: structured = await summary_service.call_llm_structured( state=state, @@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o response_model=response_model, fallback_value=None ) - # 验证结构化响应 + # Validate structured response if structured is None: logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" - # 根据操作类型提取答案 + # Extract answer based on operation type if operation_name == "summary": aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" else: - # 处理RetrieveSummaryResponse + # Handle RetrieveSummaryResponse if hasattr(structured, 'data') and structured.data: aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" else: logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" - # 验证答案不为空 + # Validate answer is not empty if not aimessages or aimessages.strip() == "": aimessages = "信息不足,无法回答" @@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o except Exception as e: logger.error(f"结构化输出失败: {e}", exc_info=True) - # 尝试非结构化输出作为fallback + 
# Try unstructured output as fallback try: logger.info("尝试非结构化输出作为fallback") response = await summary_service.call_llm_simple( @@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) if response and response.strip(): - # 简单清理响应 + # Simple response cleaning cleaned_response = response.strip() - # 移除可能的JSON标记 + # Remove possible JSON markers if cleaned_response.startswith('```'): lines = cleaned_response.split('\n') cleaned_response = '\n'.join(lines[1:-1]) @@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o async def summary_redis_save(state: ReadState, aimessages) -> ReadState: + """ + Save summary results to Redis session storage + + Stores the generated summary and user query in Redis for session management + and conversation history tracking. + + Args: + state: ReadState containing user and query information + aimessages: Generated summary message to save + + Returns: + ReadState: Updated state after saving to Redis + """ data = state.get("data", '') end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( @@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState: async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: + """ + Format summary results for different output types + + Creates structured output formats for both input summary and retrieval summary + operations, including metadata and intermediate results for frontend display. 
+ + Args: + state: ReadState containing storage and user information + aimessages: Generated summary message + raw_results: Raw search/retrieval results + + Returns: + tuple: (input_summary, retrieve_summary) formatted result dictionaries + """ storage_type = state.get("storage_type", '') user_rag_memory_id = state.get("user_rag_memory_id", '') data = state.get("data", '') @@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState async def Input_Summary(state: ReadState) -> ReadState: + """ + Generate quick input summary from retrieved information + + Performs fast retrieval and generates a quick summary response for user queries. + This function prioritizes speed by only searching summary nodes and provides + immediate feedback to users. + + Args: + state: ReadState containing user query, storage configuration, and context + + Returns: + ReadState: Dictionary containing summary results with status and metadata + """ start = time.time() storage_type = state.get("storage_type", '') memory_config = state.get('memory_config', None) @@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState: async def Retrieve_Summary(state: ReadState) -> ReadState: + """ + Generate comprehensive summary from retrieved expansion issues + + Processes retrieved expansion issues and generates a detailed summary using LLM. + This function handles complex retrieval results and provides comprehensive answers + based on expanded query results. 
+ + Args: + state: ReadState containing retrieve data with expansion issues + + Returns: + ReadState: Dictionary containing comprehensive summary results + """ retrieve = state.get("retrieve", '') history = await summary_history(state) import json @@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState: duration = 0.0 log_time('Retrieval summary', duration) - # 修复协程调用 - 先await,然后访问返回值 + # Fixed coroutine call - await first, then access return value summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary": summary} async def Summary(state: ReadState) -> ReadState: + """ + Generate final comprehensive summary from verified data + + Creates the final summary using verified expansion issues and conversation history. + This function processes verified data to generate the most comprehensive and + accurate response to user queries. + + Args: + state: ReadState containing verified data and query information + + Returns: + ReadState: Dictionary containing final summary results + """ start = time.time() query = state.get("data", '') verify = state.get("verify", '') @@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState: duration = 0.0 log_time('Retrieval summary', duration) - # 修复协程调用 - 先await,然后访问返回值 + # Fixed coroutine call - await first, then access return value summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary": summary} async def Summary_fails(state: ReadState) -> ReadState: + """ + Generate fallback summary when normal summary process fails + + Provides a fallback summary generation mechanism when the standard summary + process encounters errors or fails to produce satisfactory results. Uses + a specialized failure template to handle edge cases. 
+ + Args: + state: ReadState containing verified data and failure context + + Returns: + ReadState: Dictionary containing fallback summary results + """ storage_type = state.get("storage_type", '') user_rag_memory_id = state.get("user_rag_memory_id", '') history = await summary_history(state) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index 3f7b491e..3a04b411 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -18,24 +18,46 @@ logger = get_agent_logger(__name__) class VerificationNodeService(LLMServiceMixin): - """验证节点服务类""" + """ + Verification node service class + + Handles data verification operations using LLM services. Inherits from + LLMServiceMixin to provide structured LLM calling capabilities for + verifying and validating retrieved information. + + Attributes: + template_service: Service for rendering Jinja2 templates + """ def __init__(self): super().__init__() self.template_service = TemplateService(template_root) -# 创建全局服务实例 +# Create global service instance verification_service = VerificationNodeService() async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): - """处理验证结果并生成输出格式""" + """ + Process verification results and generate output format + + Transforms VerificationResult objects into structured output format suitable + for frontend consumption. Handles conversion of VerificationItem objects to + dictionary format and adds metadata for tracking. 
+ + Args: + state: ReadState containing storage and user configuration + messages_deal: VerificationResult containing verification outcomes + + Returns: + dict: Formatted verification result with status and metadata + """ storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') - # 将 VerificationItem 对象转换为字典列表 + # Convert VerificationItem objects to dictionary list verified_data = [] if messages_deal.expansion_issue: for item in messages_deal.expansion_issue: @@ -89,7 +111,7 @@ async def Verify(state: ReadState): logger.info("Verify: 开始渲染模板") - # 生成 JSON schema 以指导 LLM 输出正确格式 + # Generate JSON schema to guide LLM output format json_schema = VerificationResult.model_json_schema() system_prompt = await verification_service.template_service.render_template( @@ -104,8 +126,8 @@ async def Verify(state: ReadState): # 使用优化的LLM服务,添加超时保护 logger.info("Verify: 开始调用 LLM") try: - # 添加 asyncio.wait_for 超时包裹,防止无限等待 - # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) + # Add asyncio.wait_for timeout wrapper to prevent infinite waiting + # Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds) with get_db_context() as db_session: structured = await asyncio.wait_for( @@ -122,7 +144,7 @@ async def Verify(state: ReadState): "reason": "验证失败或超时" } ), - timeout=150.0 # 150秒超时 + timeout=150.0 # 150 second timeout ) logger.info(f"Verify: LLM 调用完成,result={structured}") except asyncio.TimeoutError: diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index cba1b230..bddae618 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import ( @asynccontextmanager async def make_read_graph(): - """创建并返回 LangGraph 工作流""" + """ + Create and return a LangGraph workflow for memory reading operations 
+ + Builds a state graph workflow that handles memory retrieval, problem analysis, + verification, and summarization. The workflow includes nodes for content input, + problem splitting, retrieval, verification, and various summary operations. + + Yields: + StateGraph: Compiled LangGraph workflow for memory reading + + Raises: + Exception: If workflow creation fails + """ try: # Build workflow graph workflow = StateGraph(ReadState) @@ -48,7 +60,7 @@ async def make_read_graph(): workflow.add_node("Summary", Summary) workflow.add_node("Summary_fails", Summary_fails) - # 添加边 + # Add edges to define workflow flow workflow.add_edge(START, "content_input") workflow.add_conditional_edges("content_input", Split_continue) workflow.add_edge("Input_Summary", END) @@ -63,7 +75,7 @@ async def make_read_graph(): '''-----''' # workflow.add_edge("Retrieve", END) - # 编译工作流 + # Compile workflow graph = workflow.compile() yield graph @@ -72,108 +84,3 @@ async def make_read_graph(): raise finally: print("工作流创建完成") - - -async def main(): - """主函数 - 运行工作流""" - message = "昨天有什么好看的电影" - end_user_id = '88a459f5_text09' # 组ID - storage_type = 'neo4j' # 存储类型 - search_switch = '1' # 搜索开关 - user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID - - # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=17, # 改为整数 - service_name="MemoryAgentService" - ) - import time - start = time.time() - try: - async with make_read_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, - "end_user_id": end_user_id - , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, - "memory_config": memory_config} - # 获取节点更新信息 - _intermediate_outputs = [] - summary = '' - - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - 
for node_name, node_data in update_event.items(): - print(f"处理节点: {node_name}") - - # 处理不同Summary节点的返回结构 - if 'Summary' in node_name: - if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: - summary = node_data['InputSummary']['summary_result'] - elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: - summary = node_data['RetrieveSummary']['summary_result'] - elif 'summary' in node_data and 'summary_result' in node_data['summary']: - summary = node_data['summary']['summary_result'] - elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: - summary = node_data['SummaryFails']['summary_result'] - - spit_data = node_data.get('spit_data', {}).get('_intermediate', None) - if spit_data and spit_data != [] and spit_data != {}: - _intermediate_outputs.append(spit_data) - - # Problem_Extension 节点 - problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) - if problem_extension and problem_extension != [] and problem_extension != {}: - _intermediate_outputs.append(problem_extension) - - # Retrieve 节点 - retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) - if retrieve_node and retrieve_node != [] and retrieve_node != {}: - _intermediate_outputs.extend(retrieve_node) - - # Verify 节点 - verify_n = node_data.get('verify', {}).get('_intermediate', None) - if verify_n and verify_n != [] and verify_n != {}: - _intermediate_outputs.append(verify_n) - - # Summary 节点 - summary_n = node_data.get('summary', {}).get('_intermediate', None) - if summary_n and summary_n != [] and summary_n != {}: - _intermediate_outputs.append(summary_n) - - # # 过滤掉空值 - # _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - # - # # 优化搜索结果 - # print("=== 开始优化搜索结果 ===") - # optimized_outputs = merge_multiple_search_results(_intermediate_outputs) - # result=reorder_output_results(optimized_outputs) - # # 
保存优化后的结果到文件 - # with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f: - # import json - # f.write(json.dumps(result, indent=4, ensure_ascii=False)) - # - print(f"=== 最终摘要 ===") - print(summary) - - except Exception as e: - import traceback - traceback.print_exc() - finally: - db_session.close() - - end = time.time() - print(100 * 'y') - print(f"总耗时: {end - start}s") - print(100 * 'y') - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/routers.py b/api/app/core/memory/agent/langgraph_graph/routing/routers.py index 004e03b3..d6ca3333 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/routers.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/routers.py @@ -1,13 +1,13 @@ - from typing import Literal from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState - logger = get_agent_logger(__name__) counter = COUNTState(limit=3) -def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: + + +def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]: """ Determine routing based on search_switch value. @@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa return 'Input_Summary' return 'Split_The_Problem' # 默认情况 + def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: """ Determine routing based on search_switch value. 
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]: elif search_switch == '1': return 'Retrieve_Summary' return 'Retrieve_Summary' # Default based on business logic + + def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]: - status=state.get('verify', '')['status'] + status = state.get('verify', '')['status'] # loop_count = counter.get_total() if "success" in status: # counter.reset() @@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co # if loop_count < 2: # Maximum loop count is 3 # return "content_input" # else: - # counter.reset() + # counter.reset() return "Summary_fails" else: # Add default return value to avoid returning None diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 895f61ac..6176caf5 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -2,77 +2,104 @@ import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage - +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ -from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store +from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context, 
get_db +from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') + async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - # RAG 模式:组合消息为字符串格式(保持原有逻辑) + """ + Write messages to RAG storage system + + Combines user and AI messages into a single string format and stores them + in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. + + Args: + end_user_id: User identifier for the conversation + user_message: User's input message content + ai_message: AI's response message content + user_rag_memory_id: RAG memory identifier for storage location + """ + # RAG mode: combine messages into string format (maintain original logic) combined_message = f"user: {user_message}\nassistant: {ai_message}" await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') -async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, - actual_config_id, long_term_messages=[]): + + +async def write( + storage_type, + end_user_id, + user_message, + ai_message, + user_rag_memory_id, + actual_end_user_id, + actual_config_id, + long_term_messages=None +): """ - 写入记忆(支持结构化消息) + Write memory with structured message support + + Handles memory writing operations for different storage types (Neo4j/RAG). + Supports both individual message pairs and batch long-term message processing. 
Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID + storage_type: Storage type identifier ("neo4j" or "rag") + end_user_id: Terminal user identifier + user_message: User message content + ai_message: AI response content + user_rag_memory_id: RAG memory identifier + actual_end_user_id: Actual user identifier for storage + actual_config_id: Configuration identifier + long_term_messages: Optional list of structured messages for batch processing - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + Logic explanation: + - RAG mode: Combines user_message and ai_message into string format, maintains original logic + - Neo4j mode: Uses structured message lists + 1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant] + 2. If only user_message exists: Creates single user message [user] (for historical memory scenarios) + 3. 
Each message is converted to independent Chunk, preserving speaker field """ - db = next(get_db()) - try: + if long_term_messages is None: + long_term_messages = [] + with get_db_context() as db: actual_config_id = resolve_config_id(actual_config_id, db) - # Neo4j 模式:使用结构化消息列表 + # Neo4j mode: Use structured message lists structured_messages = [] - # 始终添加用户消息(如果不为空) + # Always add user message (if not empty) if isinstance(user_message, str) and user_message.strip() != "": structured_messages.append({"role": "user", "content": user_message}) - # 只有当 AI 回复不为空时才添加 assistant 消息 + # Only add assistant message when AI reply is not empty if isinstance(ai_message, str) and ai_message.strip() != "": structured_messages.append({"role": "assistant", "content": ai_message}) - # 如果提供了 long_term_messages,使用它替代 structured_messages + # If long_term_messages provided, use it to replace structured_messages if long_term_messages and isinstance(long_term_messages, list): structured_messages = long_term_messages elif long_term_messages and isinstance(long_term_messages, str): - # 如果是 JSON 字符串,先解析 + # If it's a JSON string, parse it first try: structured_messages = json.loads(long_term_messages) except json.JSONDecodeError: logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") - # 如果没有消息,直接返回 + # If no messages, return directly if not structured_messages: logger.warning(f"No messages to write for user {actual_end_user_id}") return @@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me logger.info( f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: JSON 字符串格式的消息列表 - str(actual_config_id), # config_id: 配置ID字符串 + actual_end_user_id, # end_user_id: User ID + structured_messages, # message: JSON string format message list + 
str(actual_config_id), # config_id: Configuration ID string storage_type, # storage_type: "neo4j" - user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() -async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + +async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): + """ + Save long-term memory data to database + + Handles the storage of long-term memory data based on different strategies + (chunk-based or aggregate-based) and manages the transition from short-term + to long-term memory storage. + + Args: + long_term_messages: Long-term message data to be saved + actual_config_id: Configuration identifier for memory settings + end_user_id: User identifier for memory association + type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + scope: Scope/window size for memory processing + """ with get_db_context() as db_session: repo = LongTermMemoryRepository(db_session) - from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: data = await format_parsing(result, "dict") chunk_data = data[:scope] - if len(chunk_data)==scope: + if len(chunk_data) == scope: repo.upsert(end_user_id, chunk_data) logger.info(f'---------写入短长期-----------') else: @@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type, logger.info(f'写入短长期:') +"""Window-based dialogue 
processing""" -'''根据窗口''' -async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): - ''' - 根据窗口获取redis数据,写入neo4j: - Args: - end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - scope:窗口大小 - ''' - scope=scope + +async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): + """ + Process dialogue based on window size and write to Neo4j + + Manages conversation data based on a sliding window approach. When the window + reaches the specified scope size, it triggers long-term memory storage to Neo4j. + + Args: + end_user_id: Terminal user identifier + memory_config: Memory configuration object containing settings + langchain_messages: Original message data list + scope: Window size determining when to trigger long-term storage + """ + scope = scope is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] @@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): elif int(is_end_user_id) == int(scope): logger.info('写入长期记忆NEO4J') formatted_messages = (redis_messages) - # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: config_id = memory_config - - await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, - config_id, formatted_messages) + + await write( + AgentMemory_Long_Term.STORAGE_NEO4J, + end_user_id, + "", + "", + None, + end_user_id, + config_id, + formatted_messages + ) count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) -"""根据时间""" -async def memory_long_term_storage(end_user_id,memory_config,time): - ''' - 根据时间获取redis数据,写入neo4j: - Args: - end_user_id: 
终端用户ID - memory_config: 内存配置对象 - ''' +"""Time-based memory processing""" + + +async def memory_long_term_storage(end_user_id, memory_config, time): + """ + Process memory storage based on time intervals and write to Neo4j + + Retrieves Redis data based on time intervals and writes it to Neo4j for + long-term storage. This function handles time-based memory consolidation. + + Args: + end_user_id: Terminal user identifier + memory_config: Memory configuration object containing settings + time: Time interval for data retrieval + """ long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - format_messages = (long_time_data) - messages=[] - memory_config=memory_config.config_id + format_messages = long_time_data + messages = [] + memory_config = memory_config.config_id for i in format_messages: - message=json.loads(i['Query']) - messages+= message - if format_messages!=[]: + message = json.loads(i['Query']) + messages += message + if format_messages: await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, memory_config, messages) -'''聚合判断''' + + async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ - 聚合判断函数:判断输入句子和历史消息是否描述同一事件 + Aggregation judgment function: determine if input sentence and historical messages describe the same event + + Uses LLM-based analysis to determine whether new messages should be aggregated with existing + historical data or stored as separate events. This helps optimize memory storage and retrieval. 
Args: - end_user_id: 终端用户ID - ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - memory_config: 内存配置对象 - """ + end_user_id: Terminal user identifier + ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + memory_config: Memory configuration object containing LLM settings + Returns: + dict: Aggregation judgment result containing is_same_event flag and processed output + """ + history = None try: - # 1. 获取历史会话数据(使用新方法) + # 1. Get historical session data (using new method) result = write_store.get_all_sessions_by_end_user_id(end_user_id) history = await format_parsing(result) if not result: @@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config output_value = structured.output if isinstance(output_value, list): output_value = [ - {"role": msg.role, "content": msg.content} + {"role": msg.role, "content": msg.content} for msg in output_value ] @@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config await write("neo4j", end_user_id, "", "", None, end_user_id, memory_config.config_id, output_value) return result_dict - + except Exception as e: print(f"[aggregate_judgment] 发生错误: {e}") import traceback traceback.print_exc() - + return { "is_same_event": False, "output": ori_messages, "messages": ori_messages, "history": history if 'history' in locals() else [], "error": str(e) - } \ No newline at end of file + } diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index fcbb18e3..9bd2b2cf 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -2,41 +2,53 @@ import asyncio import json from datetime import datetime, timedelta - from langchain.tools import tool from pydantic import BaseModel, Field - from 
app.core.memory.src.search import ( search_by_temporal, search_by_keyword_temporal, ) + def extract_tool_message_content(response): - """从agent响应中提取ToolMessage内容和工具名称""" + """ + Extract ToolMessage content and tool names from agent response + + Parses agent response messages to extract tool execution results and metadata. + Handles JSON parsing and provides structured access to tool output data. + + Args: + response: Agent response dictionary containing messages + + Returns: + dict: Dictionary containing tool_name and parsed content, or None if no tool message found + - tool_name: Name of the executed tool + - content: Parsed tool execution result (JSON or raw text) + """ messages = response.get('messages', []) for message in messages: if hasattr(message, 'tool_call_id') and hasattr(message, 'content'): - # 这是一个ToolMessage + # This is a ToolMessage tool_content = message.content tool_name = None - # 尝试获取工具名称 + # Try to get tool name if hasattr(message, 'name'): tool_name = message.name elif hasattr(message, 'tool_name'): tool_name = message.tool_name try: - # 解析JSON内容 + # Parse JSON content parsed_content = json.loads(tool_content) return { 'tool_name': tool_name, 'content': parsed_content } except json.JSONDecodeError: - # 如果不是JSON格式,直接返回内容 + # If not JSON format, return content directly return { 'tool_name': tool_name, 'content': tool_content @@ -46,38 +58,61 @@ def extract_tool_message_content(response): class TimeRetrievalInput(BaseModel): - """时间检索工具的输入模式""" + """ + Input schema for time retrieval tool + + Defines the expected input parameters for time-based retrieval operations. + Used for validation and documentation of tool parameters. 
+ + Attributes: + context: User input query content for search + end_user_id: Group ID for filtering search results, defaults to test user + """ context: str = Field(description="用户输入的查询内容") end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + def create_time_retrieval_tool(end_user_id: str): """ - 创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range + + Creates a specialized time-based retrieval tool that searches for statements within + specified time ranges. Includes field cleaning functionality to remove unnecessary + metadata from search results. + + Args: + end_user_id: User identifier for scoping search results + + Returns: + function: Configured TimeRetrievalWithGroupId tool function """ - + def clean_temporal_result_fields(data): """ - 清理时间搜索结果中不需要的字段,并修改结构 + Clean unnecessary fields from temporal search results and modify structure + + Removes metadata fields that are not needed for end-user consumption and + restructures the response format for better usability. 
Args: - data: 要清理的数据 + data: Data to be cleaned (dict, list, or other types) Returns: - 清理后的数据 + Cleaned data with unnecessary fields removed """ - # 需要过滤的字段列表 + # List of fields to filter out fields_to_remove = { - 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', + 'id', 'apply_id', 'user_id', 'chunk_id', 'created_at', 'valid_at', 'invalid_at', 'statement_ids' } - + if isinstance(data, dict): cleaned = {} for key, value in data.items(): if key == 'statements' and isinstance(value, dict) and 'statements' in value: - # 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]} + # Change statements: {"statements": [...]} to time_search: {"statements": [...]} cleaned_value = clean_temporal_result_fields(value) - # 进一步将内部的 statements 改为 time_search + # Further change internal statements to time_search if 'statements' in cleaned_value: cleaned['results'] = { 'time_search': cleaned_value['statements'] @@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str): return [clean_temporal_result_fields(item) for item in data] else: return data - + @tool - def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, + end_user_id_param: str = None, clean_output: bool = True) -> str: """ - 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 - 显式接收参数: - - context: 查询上下文内容 - - start_date: 开始时间(可选,格式:YYYY-MM-DD) - - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - end_user_id_param: 组ID(可选,用于覆盖默认组ID) - - clean_output: 是否清理输出中的元数据字段 - -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields + + Performs time-based search operations with automatic metadata filtering. Supports + flexible date range specification and provides clean, user-friendly output. 
+ + Explicit parameters: + - context: Query context content + - start_date: Start time (optional, format: YYYY-MM-DD) + - end_date: End time (optional, format: YYYY-MM-DD) + - end_user_id_param: Group ID (optional, overrides default group ID) + - clean_output: Whether to clean metadata fields from output + - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") + + Returns: + str: JSON formatted search results with temporal data """ + async def _async_search(): - # 使用传入的参数或默认值 + # Use passed parameters or default values actual_end_user_id = end_user_id_param or end_user_id actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") - # 基本时间搜索 + # Basic time search results = await search_by_temporal( end_user_id=actual_end_user_id, start_date=actual_start_date, @@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str): limit=10 ) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_temporal_result_fields(results) else: cleaned_results = results return json.dumps(cleaned_results, ensure_ascii=False, indent=2) - + return asyncio.run(_async_search()) @tool - def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str: + def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, + clean_output: bool = True) -> str: """ - 优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段 - 显式接收参数: - - context: 查询内容 - - days_back: 向前搜索的天数,默认7天 - - start_date: 开始时间(可选,格式:YYYY-MM-DD) - - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - clean_output: 是否清理输出中的元数据字段 - - end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") + Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields + + Performs 
combined keyword and temporal search operations with automatic metadata + filtering. Provides more targeted search results by combining content relevance + with time-based filtering. + + Explicit parameters: + - context: Query content for keyword matching + - days_back: Number of days to search backwards, default 7 days + - start_date: Start time (optional, format: YYYY-MM-DD) + - end_date: End time (optional, format: YYYY-MM-DD) + - clean_output: Whether to clean metadata fields from output + - end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d") + + Returns: + str: JSON formatted search results combining keyword and temporal data """ + async def _async_search(): actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d") - # 关键词时间搜索 + # Keyword time search results = await search_by_keyword_temporal( query_text=context, end_user_id=end_user_id, @@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str): limit=15 ) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_temporal_result_fields(results) else: @@ -162,51 +216,60 @@ def create_time_retrieval_tool(end_user_id: str): return json.dumps(cleaned_results, ensure_ascii=False, indent=2) return asyncio.run(_async_search()) - + return TimeRetrievalWithGroupId def create_hybrid_retrieval_tool_async(memory_config, **search_params): """ - 创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段 + Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields + + Creates an advanced hybrid search tool that combines multiple search strategies + (keyword, vector, hybrid) with automatic result cleaning and formatting. 
Args: - memory_config: 内存配置对象 - **search_params: 搜索参数,包含end_user_id, limit, include等 + memory_config: Memory configuration object containing LLM and search settings + **search_params: Search parameters including end_user_id, limit, include, etc. + + Returns: + function: Configured HybridSearch tool function with async capabilities """ - + def clean_result_fields(data): """ - 递归清理结果中不需要的字段 + Recursively clean unnecessary fields from results + + Removes metadata fields that are not needed for end-user consumption, + improving readability and reducing response size. Args: - data: 要清理的数据(可能是字典、列表或其他类型) + data: Data to be cleaned (can be dict, list, or other types) Returns: - 清理后的数据 + Cleaned data with unnecessary fields removed """ - # 需要过滤的字段列表 - # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 + # List of fields to filter out + # TODO: fact_summary functionality temporarily disabled, will be enabled after future development fields_to_remove = { - 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', - 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary" + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary" } - + if isinstance(data, dict): - # 对字典进行清理 + # Clean dictionary cleaned = {} for key, value in data.items(): if key not in fields_to_remove: - cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据 + cleaned[key] = clean_result_fields(value) # Recursively clean nested data return cleaned elif isinstance(data, list): - # 对列表中的每个元素进行清理 + # Clean each element in list return [clean_result_fields(item) for item in data] else: - # 其他类型直接返回 + # Return other types directly return data - + @tool async def HybridSearch( context: str, @@ -216,57 +279,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): rerank_alpha: float = 0.6, 
use_forgetting_rerank: bool = False, use_llm_rerank: bool = False, - clean_output: bool = True # 新增:是否清理输出字段 + clean_output: bool = True # New: whether to clean output fields ) -> str: """ - 优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段 + Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields + + Provides comprehensive search capabilities combining multiple search strategies + with intelligent result ranking and automatic metadata filtering for clean output. Args: - context: 查询内容 - search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') - limit: 结果数量限制 - end_user_id: 组ID,用于过滤搜索结果 - rerank_alpha: 重排序权重参数 - use_forgetting_rerank: 是否使用遗忘重排序 - use_llm_rerank: 是否使用LLM重排序 - clean_output: 是否清理输出中的元数据字段 + context: Query content for search + search_type: Search type ('keyword', 'embedding', 'hybrid') + limit: Result quantity limit + end_user_id: Group ID for filtering search results + rerank_alpha: Reranking weight parameter for result scoring + use_forgetting_rerank: Whether to use forgetting-based reranking + use_llm_rerank: Whether to use LLM-based reranking + clean_output: Whether to clean metadata fields from output + + Returns: + str: JSON formatted comprehensive search results """ try: - # 导入run_hybrid_search函数 + # Import run_hybrid_search function from app.core.memory.src.search import run_hybrid_search - # 合并参数,优先使用传入的参数 + # Merge parameters, prioritize passed parameters final_params = { "query_text": context, "search_type": search_type, "end_user_id": end_user_id or search_params.get("end_user_id"), "limit": limit or search_params.get("limit", 10), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), - "output_path": None, # 不保存到文件 + "output_path": None, # Don't save to file "memory_config": memory_config, "rerank_alpha": rerank_alpha, "use_forgetting_rerank": use_forgetting_rerank, "use_llm_rerank": use_llm_rerank } - # 执行混合检索 + # Execute hybrid retrieval 
raw_results = await run_hybrid_search(**final_params) - # 清理结果中不需要的字段 + # Clean unnecessary fields from results if clean_output: cleaned_results = clean_result_fields(raw_results) else: cleaned_results = raw_results - # 格式化返回结果 + # Format return results formatted_results = { "search_query": context, "search_type": search_type, "results": cleaned_results } - + return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str) - + except Exception as e: error_result = { "error": f"混合检索失败: {str(e)}", @@ -275,38 +344,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): "timestamp": datetime.now().isoformat() } return json.dumps(error_result, ensure_ascii=False, indent=2) - + return HybridSearch def create_hybrid_retrieval_tool_sync(memory_config, **search_params): """ - 创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段 + Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields + + Creates a synchronous wrapper around the async hybrid search functionality, + making it compatible with synchronous tool execution environments. Args: - memory_config: 内存配置对象 - **search_params: 搜索参数 + memory_config: Memory configuration object containing search settings + **search_params: Search parameters for configuration + + Returns: + function: Configured HybridSearchSync tool function """ + @tool def HybridSearchSync( - context: str, - search_type: str = "hybrid", - limit: int = 10, - end_user_id: str = None, - clean_output: bool = True + context: str, + search_type: str = "hybrid", + limit: int = 10, + end_user_id: str = None, + clean_output: bool = True ) -> str: """ - 优化的混合检索工具(同步版本),自动过滤不需要的元数据字段 + Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields + + Provides the same hybrid search capabilities as the async version but in a + synchronous execution context. Automatically handles async-to-sync conversion. 
Args: - context: 查询内容 - search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') - limit: 结果数量限制 - end_user_id: 组ID,用于过滤搜索结果 - clean_output: 是否清理输出中的元数据字段 + context: Query content for search + search_type: Search type ('keyword', 'embedding', 'hybrid') + limit: Result quantity limit + end_user_id: Group ID for filtering search results + clean_output: Whether to clean metadata fields from output + + Returns: + str: JSON formatted search results """ + async def _async_search(): - # 创建异步工具并执行 + # Create async tool and execute async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params) return await async_tool.ainvoke({ "context": context, @@ -315,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): "end_user_id": end_user_id, "clean_output": clean_output }) - + return asyncio.run(_async_search()) - - return HybridSearchSync \ No newline at end of file + + return HybridSearchSync diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index 9ce581ee..e11a2085 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -1,20 +1,28 @@ import json from langchain_core.messages import HumanMessage, AIMessage -async def format_parsing(messages: list,type:str='string'): + + +async def format_parsing(messages: list, type: str = 'string'): """ - 格式化解析消息列表 + Format and parse message lists into different output types + + Processes message lists from storage and converts them into either string format + or dictionary format based on the specified type parameter. Handles JSON parsing + and role-based message organization. 
Args: - messages: 消息列表 - type: 返回类型 ('string' 或 'dict') + messages: List of message objects from storage containing message data + type: Return type specification ('string' for text format, 'dict' for key-value pairs) Returns: - 格式化后的消息列表 + list: Formatted message list in the specified format + - 'string': List of formatted text messages with role prefixes + - 'dict': List of dictionaries mapping user messages to AI responses """ result = [] - user=[] - ai=[] + user = [] + ai = [] for message in messages: hstory_messages = message['messages'] @@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human' or role=="user": + if role == 'human' or role == "user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict" : - if role == 'human' or role=="user": - user.append( content) + if type == "dict": + if role == 'human' or role == "user": + user.append(content) else: ai.append(content) if type == "dict": - for key,values in zip(user,ai): - result.append({key:values}) + for key, values in zip(user, ai): + result.append({key: values}) return result + async def messages_parse(messages: list | dict): - user=[] - ai=[] - database=[] + """ + Parse messages from storage format into user-AI conversation pairs + + Extracts and organizes conversation data from stored message format, + separating user and AI messages and pairing them for database storage. 
+ + Args: + messages: List or dictionary containing stored message data with Query fields + + Returns: + list: List of dictionaries containing user-AI message pairs for database storage + """ + user = [] + ai = [] + database = [] for message in messages: Query = message['Query'] Query = json.loads(Query) @@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict): ai.append(data['content']) for key, values in zip(user, ai): database.append({key, values}) - return database + return database -async def agent_chat_messages(user_content,ai_content): +async def agent_chat_messages(user_content, ai_content): + """ + Create structured chat message format for agent conversations + + Formats user and AI content into a standardized message structure suitable + for agent processing and storage. Creates role-based message objects. + + Args: + user_content: User's message content string + ai_content: AI's response content string + + Returns: + list: List of structured message dictionaries with role and content fields + """ messages = [ { "role": "user", diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 1134acc7..bf3c6597 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService - warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -42,10 +41,26 @@ async def make_write_graph(): yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, 
window_dialogue,aggregate_judgment + +async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', + end_user_id: str = '', scope: int = 6): + """ + Handle long-term memory storage with different strategies + + Supports multiple storage strategies including chunk-based, time-based, + and aggregate judgment approaches for long-term memory persistence. + + Args: + long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') + langchain_messages: List of messages to store + memory_config: Memory configuration identifier + end_user_id: User group identifier + scope: Scope parameter for chunk-based storage (default: 6) + """ + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, (langchain_messages)) + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) @@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ config_id=memory_config, # 改为整数 service_name="MemoryAgentService" ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - """方案三:聚合判断""" + if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: + '''Strategy 1: Dialogue window with 6 rounds of conversation''' + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: + """Time-based strategy""" + await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) + if long_term_type == 
AgentMemory_Long_Term.STRATEGY_AGGREGATE: + """Strategy 3: Aggregate judgment""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) +async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): + """ + Write long-term memory with different storage types -async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + Handles both RAG-based storage and traditional memory storage approaches. + For traditional storage, uses chunk-based strategy with paired user-AI messages. + + Args: + storage_type: Type of storage (RAG or traditional) + end_user_id: User group identifier + message_chat: User message content + aimessages: AI response messages + user_rag_memory_id: RAG memory identifier + actual_config_id: Actual configuration ID + """ from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) else: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE long_term_messages = await agent_chat_messages(message_chat, aimessages) @@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_ # # if __name__ == "__main__": # import asyncio -# asyncio.run(main()) \ No newline at end of file +# asyncio.run(main()) diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index 
1c183422..ea8add48 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -8,10 +8,11 @@ from langgraph.graph import add_messages PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3]) + class WriteState(TypedDict): - ''' + """ Langgrapg Writing TypedDict - ''' + """ messages: Annotated[list[AnyMessage], add_messages] end_user_id: str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] @@ -20,6 +21,7 @@ class WriteState(TypedDict): data: str language: str # 语言类型 ("zh" 中文, "en" 英文) + class ReadState(TypedDict): """ LangGraph 工作流状态定义 @@ -43,18 +45,20 @@ class ReadState(TypedDict): config_id: str data: str # 新增字段用于传递内容 spit_data: dict # 新增字段用于传递问题分解结果 - problem_extension:dict + problem_extension: dict storage_type: str user_rag_memory_id: str llm_id: str embedding_id: str memory_config: object # 新增字段用于传递内存配置对象 - retrieve:dict + retrieve: dict RetrieveSummary: dict InputSummary: dict verify: dict SummaryFails: dict summary: dict + + class COUNTState: """ 工作流对话检索内容计数器 @@ -99,6 +103,7 @@ class COUNTState: self.total = 0 print("[COUNTState] 已重置为 0") + def deduplicate_entries(entries): seen = set() deduped = [] @@ -109,6 +114,7 @@ def deduplicate_entries(entries): deduped.append(entry) return deduped + def merge_to_key_value_pairs(data, query_key, result_key): grouped = defaultdict(list) for item in data: @@ -142,4 +148,4 @@ def convert_extended_question_to_question(data): return [convert_extended_question_to_question(item) for item in data] else: # 其他类型直接返回 - return data \ No newline at end of file + return data diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index 024e320a..147ed777 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ 
b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -5,7 +5,7 @@ from typing import List, Dict, Optional from app.core.logging_config import get_memory_logger from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt -from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 +from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤 from app.core.memory.models.triplet_models import TripletExtractionResponse from app.core.memory.models.message_models import DialogData, Statement from app.core.memory.models.ontology_extraction_models import OntologyTypeList @@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger logger = get_memory_logger(__name__) - class TripletExtractor: """Extracts knowledge triplets and entities from statements using LLM""" def __init__( - self, - llm_client: OpenAIClient, - ontology_types: Optional[OntologyTypeList] = None, - language: str = "zh"): + self, + llm_client: OpenAIClient, + ontology_types: Optional[OntologyTypeList] = None, + language: str = "zh" + ): """Initialize the TripletExtractor with an LLM client Args: @@ -65,7 +65,8 @@ class TripletExtractor: # Create messages for LLM messages = [ - {"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."}, + {"role": "system", + "content": "You are an expert at extracting knowledge triplets and entities from text. 
Follow the provided instructions carefully and return valid JSON."}, {"role": "user", "content": prompt_content} ] @@ -116,7 +117,8 @@ class TripletExtractor: logger.error(f"Error processing statement: {e}", exc_info=True) return TripletExtractionResponse(triplets=[], entities=[]) - async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]: + async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[ + str, TripletExtractionResponse]: """Extract triplets and entities from statements Args: diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index 09c7ef3d..b2a594c6 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -1,11 +1,11 @@ """ -自我反思引擎实现 +Self-Reflection Engine Implementation -该模块实现了记忆系统的自我反思功能,包括: -1. 基于时间的反思 - 根据时间周期触发反思 -2. 基于事实的反思 - 检测记忆冲突并解决 -3. 综合反思 - 整合多种反思策略 -4. 反思结果应用 - 更新记忆库 +This module implements the self-reflection functionality of the memory system, including: +1. Time-based reflection - Triggers reflection based on time cycles +2. Fact-based reflection - Detects and resolves memory conflicts +3. Comprehensive reflection - Integrates multiple reflection strategies +4. 
Reflection result application - Updates memory database """ import asyncio @@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import ( ) from pydantic import BaseModel -# 配置日志 +# Configure logging _root_logger = logging.getLogger() if not _root_logger.handlers: logging.basicConfig( @@ -49,35 +49,62 @@ else: _root_logger.setLevel(logging.INFO) class TranslationResponse(BaseModel): - """翻译响应模型""" + """Translation response model for language conversion""" data: str + class ReflectionRange(str, Enum): - """反思范围枚举""" - PARTIAL = "partial" # 从检索结果中反思 - ALL = "all" # 从整个数据库中反思 + """ + Reflection range enumeration + + Defines the scope of data to be included in reflection operations. + """ + PARTIAL = "partial" # Reflect from retrieval results + ALL = "all" # Reflect from entire database class ReflectionBaseline(str, Enum): - """反思基线枚举""" - TIME = "TIME" # 基于时间的反思 - FACT = "FACT" # 基于事实的反思 - HYBRID = "HYBRID" # 混合反思 + """ + Reflection baseline enumeration + + Defines the strategy or approach used for reflection operations. + """ + TIME = "TIME" # Time-based reflection + FACT = "FACT" # Fact-based reflection + HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies class ReflectionConfig(BaseModel): - """反思引擎配置""" + """ + Reflection engine configuration + + Defines all configuration parameters for the reflection engine including + operation modes, model settings, and evaluation criteria. 
+ + Attributes: + enabled: Whether reflection engine is enabled + iteration_period: Reflection cycle period (e.g., "3" hours) + reflexion_range: Scope of reflection (PARTIAL or ALL) + baseline: Reflection strategy (TIME, FACT, or HYBRID) + model_id: LLM model identifier for reflection operations + end_user_id: User identifier for scoped operations + output_example: Example output format for guidance + memory_verify: Enable memory verification checks + quality_assessment: Enable quality assessment evaluation + violation_handling_strategy: Strategy for handling violations + language_type: Language type for output ("zh" or "en") + """ enabled: bool = False - iteration_period: str = "3" # 反思周期 + iteration_period: str = "3" # Reflection cycle period reflexion_range: ReflectionRange = ReflectionRange.PARTIAL baseline: ReflectionBaseline = ReflectionBaseline.TIME - model_id: Optional[str] = None # 模型ID + model_id: Optional[str] = None # Model ID end_user_id: Optional[str] = None - output_example: Optional[str] = None # 输出示例 + output_example: Optional[str] = None # Output example - # 评估相关字段 - memory_verify: bool = True # 记忆验证 - quality_assessment: bool = True # 质量评估 - violation_handling_strategy: str = "warn" # 违规处理策略 + # Evaluation related fields + memory_verify: bool = True # Memory verification + quality_assessment: bool = True # Quality assessment + violation_handling_strategy: str = "warn" # Violation handling strategy language_type: str = "zh" class Config: @@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel): class ReflectionResult(BaseModel): - """反思结果""" + """ + Reflection operation result + + Contains comprehensive information about the outcome of a reflection operation + including success status, metrics, and execution details. 
+ + Attributes: + success: Whether the reflection operation succeeded + message: Descriptive message about the operation result + conflicts_found: Number of conflicts detected during reflection + conflicts_resolved: Number of conflicts successfully resolved + memories_updated: Number of memory entries updated in database + execution_time: Total time taken for the reflection operation + details: Additional details about the operation (optional) + """ success: bool message: str conflicts_found: int = 0 @@ -97,9 +138,22 @@ class ReflectionResult(BaseModel): class ReflectionEngine: """ - 自我反思引擎 - - 负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。 + Self-Reflection Engine + + Responsible for executing memory system self-reflection operations including + conflict detection, conflict resolution, and memory updates. Supports multiple + reflection strategies and provides comprehensive result tracking. + + The engine can operate in different modes: + - Time-based: Reflects on memories within specific time periods + - Fact-based: Detects and resolves factual conflicts in memories + - Hybrid: Combines multiple reflection strategies + + Attributes: + config: Reflection engine configuration + neo4j_connector: Neo4j database connector + llm_client: Language model client for analysis + Various function handlers for data processing and prompt rendering """ def __init__( @@ -115,18 +169,21 @@ class ReflectionEngine: update_query: Optional[str] = None ): """ - 初始化反思引擎 + Initialize reflection engine + + Sets up the reflection engine with configuration and optional dependencies. + Uses lazy initialization to avoid circular imports and optimize startup time. 
Args: - config: 反思引擎配置 - neo4j_connector: Neo4j 连接器(可选) - llm_client: LLM 客户端(可选) - get_data_func: 获取数据的函数(可选) - render_evaluate_prompt_func: 渲染评估提示词的函数(可选) - render_reflexion_prompt_func: 渲染反思提示词的函数(可选) - conflict_schema: 冲突结果 Schema(可选) - reflexion_schema: 反思结果 Schema(可选) - update_query: 更新查询语句(可选) + config: Reflection engine configuration object + neo4j_connector: Neo4j connector instance (optional, will be created if not provided) + llm_client: LLM client instance (optional, will be created if not provided) + get_data_func: Function for retrieving data (optional, uses default if not provided) + render_evaluate_prompt_func: Function for rendering evaluation prompts (optional) + render_reflexion_prompt_func: Function for rendering reflection prompts (optional) + conflict_schema: Schema for conflict result validation (optional) + reflexion_schema: Schema for reflection result validation (optional) + update_query: Query string for database updates (optional) """ self.config = config self.neo4j_connector = neo4j_connector @@ -137,14 +194,20 @@ class ReflectionEngine: self.conflict_schema = conflict_schema self.reflexion_schema = reflexion_schema self.update_query = update_query - self._semaphore = asyncio.Semaphore(5) # 默认并发数为5 + self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5 - # 延迟导入以避免循环依赖 + # Lazy import to avoid circular dependencies self._lazy_init_done = False def _lazy_init(self): - """延迟初始化,避免循环导入""" + """ + Lazy initialization to avoid circular imports + + Initializes dependencies only when needed, preventing circular import issues + and optimizing startup performance. Sets up default implementations for + any components not provided during construction. 
+ """ if self._lazy_init_done: return @@ -158,7 +221,7 @@ class ReflectionEngine: factory = MemoryClientFactory(db) self.llm_client = factory.get_llm_client(self.config.model_id) elif isinstance(self.llm_client, str): - # 如果 llm_client 是字符串(model_id),则用它初始化客户端 + # If llm_client is a string (model_id), use it to initialize the client from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.services.memory_config_service import MemoryConfigService @@ -172,10 +235,10 @@ class ReflectionEngine: model_config = config_service.get_model_config(model_id) extra_params={ - "temperature": 0.2, # 降低温度提高响应速度和一致性 - "max_tokens": 600, # 限制最大token数 - "top_p": 0.8, # 优化采样参数 - "stream": False, # 确保非流式输出以获得最快响应 + "temperature": 0.2, # Lower temperature for faster response and consistency + "max_tokens": 600, # Limit maximum token count + "top_p": 0.8, # Optimize sampling parameters + "stream": False, # Ensure non-streaming output for fastest response } self.llm_client = OpenAIClient(RedBearModelConfig( @@ -191,7 +254,7 @@ class ReflectionEngine: if self.get_data_func is None: self.get_data_func = get_data - # 导入get_data_statement函数 + # Import get_data_statement function if not hasattr(self, 'get_data_statement'): self.get_data_statement = get_data_statement @@ -223,13 +286,20 @@ class ReflectionEngine: async def execute_reflection(self, host_id) -> ReflectionResult: """ - 执行完整的反思流程 + Execute complete reflection workflow + + Performs the full reflection process including data retrieval, conflict detection, + conflict resolution, and memory updates. This is the main entry point for + reflection operations. 
+ Args: - host_id: 主机ID + host_id: Host identifier for scoping reflection operations + Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive result of the reflection operation including + success status, conflict metrics, and execution time """ - # 延迟初始化 + # Lazy initialization self._lazy_init() if not self.config.enabled: @@ -243,7 +313,7 @@ class ReflectionEngine: print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment) try: - # 1. 获取反思数据 + # 1. Get reflection data reflexion_data, statement_databasets = await self._get_reflexion_data(host_id) if not reflexion_data: return ReflectionResult( @@ -252,7 +322,7 @@ class ReflectionEngine: execution_time=asyncio.get_event_loop().time() - start_time ) - # 2. 检测冲突(基于事实的反思) + # 2. Detect conflicts (fact-based reflection) conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets) conflict_list=[] for i in conflict_data: @@ -261,7 +331,7 @@ class ReflectionEngine: conflicts_found=0 - # 3. 解决冲突 + # 3. Resolve conflicts solved_data = await self._resolve_conflicts(conflict_list, statement_databasets) if not solved_data: @@ -276,7 +346,7 @@ class ReflectionEngine: logging.info(f"解决了 {conflicts_resolved} 个冲突") - # 4. 应用反思结果(更新记忆库) + # 4. Apply reflection results (update memory database) memories_updated=await self._apply_reflection_results(solved_data) execution_time = asyncio.get_event_loop().time() - start_time @@ -302,7 +372,19 @@ class ReflectionEngine: ) async def Translate(self, text): - # 翻译中文为英文 + """ + Translate Chinese text to English + + Uses the configured LLM to translate Chinese text to English with structured output. + Provides consistent translation format for reflection results. 
+ + Args: + text: Chinese text to be translated + + Returns: + str: Translated English text + """ + # Translate Chinese to English translation_messages = [ { "role": "user", @@ -316,6 +398,19 @@ class ReflectionEngine: ) return response.data async def extract_translation(self,data): + """ + Extract and translate reflection data to English + + Processes reflection data structure and translates all Chinese content to English. + Handles nested data structures including memory verifications, quality assessments, + and reflection data while preserving the original structure. + + Args: + data: Dictionary containing reflection data with Chinese content + + Returns: + dict: Translated data structure with English content + """ end_datas={} end_datas['source_data']=await self.Translate(data['source_data']) quality_assessments = [] @@ -350,6 +445,18 @@ class ReflectionEngine: return end_datas async def reflection_run(self): + """ + Execute reflection workflow with comprehensive data processing + + Performs a complete reflection operation including conflict detection, resolution, + and result formatting. Supports both Chinese and English output based on + configuration settings. + + Returns: + dict: Comprehensive reflection results including source data, memory verifications, + quality assessments, and reflection data. Results are translated to English + if language_type is set to 'en'. + """ self._lazy_init() start_time = time.time() memory_verifies_flag = self.config.memory_verify @@ -367,7 +474,7 @@ class ReflectionEngine: result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合" # 2. 
检测冲突(基于事实的反思) conflict_data = await self._detect_conflicts(databasets, source_data) - # 遍历数据提取字段 + # Traverse data to extract fields quality_assessments = [] memory_verifies = [] for item in conflict_data: @@ -375,9 +482,9 @@ class ReflectionEngine: memory_verifies.append(item['memory_verify']) result_data['memory_verifies'] = memory_verifies result_data['quality_assessments'] = quality_assessments - conflicts_found = 0 # 初始化为整数0而不是空字符串 + conflicts_found = 0 # Initialize as integer 0 instead of empty string REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"} - # Clearn conflict_data,And memory_verify和quality_assessment + # Clean conflict_data, and memory_verify and quality_assessment cleaned_conflict_data = [] for item in conflict_data: cleaned_item = { @@ -389,7 +496,7 @@ class ReflectionEngine: for item in conflict_data: cleaned_data = [] for row in item.get("data", []): - # 删除 created_at / expired_at + # Remove created_at / expired_at cleaned_row = { k: v for k, v in row.items() @@ -402,7 +509,7 @@ class ReflectionEngine: } cleaned_conflict_data_.append(cleaned_item) print(cleaned_conflict_data_) - # 3. 解决冲突 + # 3. Resolve conflicts solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data) if not solved_data: return ReflectionResult( @@ -413,7 +520,7 @@ class ReflectionEngine: ) reflexion_data = [] - # 遍历数据提取reflexion字段 + # Traverse data to extract reflexion fields for item in solved_data: if 'results' in item: for result in item['results']: @@ -431,15 +538,24 @@ class ReflectionEngine: async def extract_fields_from_json(self): - """从example.json中提取source_data和databasets字段""" + """ + Extract source_data and databasets fields from example.json + + Reads reflection example data from the example.json file and extracts + the source data and database statements for testing and demonstration purposes. 
+ + Returns: + tuple: (source_data, databasets) extracted from the example file + Returns empty lists if file reading fails + """ prompt_dir = os.path.join(os.path.dirname(__file__), "example") try: - # 读取JSON文件 + # Read JSON file with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f: data = json.loads(f.read()) - # 提取memory_verify下的字段 + # Extract fields under memory_verify memory_verify = data.get("memory_verify", {}) source_data = memory_verify.get("source_data", []) databasets = memory_verify.get("databasets", []) @@ -451,15 +567,17 @@ class ReflectionEngine: async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]: """ - 获取反思数据 - - 根据配置的反思范围获取需要反思的记忆数据。 + Get reflection data from database + + Retrieves memory data for reflection based on the configured reflection range. + Supports both partial (from retrieval results) and full (entire database) modes. Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping data retrieval Returns: - List[Any]: 反思数据列表 + tuple: (reflexion_data, statement_data) containing memory data for reflection + Returns empty lists if query fails """ print("=== 获取反思数据 ===") @@ -484,26 +602,29 @@ class ReflectionEngine: async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]: """ - 检测冲突(基于事实的反思) - - 使用 LLM 分析记忆数据,检测其中的冲突。 + Detect conflicts (fact-based reflection) + + Uses LLM to analyze memory data and detect conflicts within the memories. + Performs comprehensive conflict detection including memory verification and + quality assessment based on configuration settings. 
Args: - data: 待检测的记忆数据 + data: Memory data to be analyzed for conflicts + statement_databasets: Statement database records for context Returns: - List[Any]: 冲突记忆列表 + List[Any]: List of detected conflicts with detailed analysis """ if not data: return [] - # 数据预处理:如果数据量太少,直接返回无冲突 + # Data preprocessing: if data is too small, return no conflicts directly if len(data) < 2: logging.info("数据量不足,无需检测冲突") return [] - # 使用转换后的数据 - # print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长 + # Use converted data + # print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs memory_verify = self.config.memory_verify logging.info("====== 冲突检测开始 ======") @@ -512,7 +633,7 @@ class ReflectionEngine: language_type=self.config.language_type try: - # 渲染冲突检测提示词 + # Render conflict detection prompt rendered_prompt = await self.render_evaluate_prompt_func( data, self.conflict_schema, @@ -526,7 +647,7 @@ class ReflectionEngine: messages = [{"role": "user", "content": rendered_prompt}] logging.info(f"提示词长度: {len(rendered_prompt)}") - # 调用 LLM 进行冲突检测 + # Call LLM for conflict detection response = await self.llm_client.response_structured( messages, self.conflict_schema @@ -539,7 +660,7 @@ class ReflectionEngine: logging.error("LLM 冲突检测输出解析失败") return [] - # 标准化返回格式 + # Standardize return format if isinstance(response, BaseModel): return [response.model_dump()] elif hasattr(response, 'dict'): @@ -553,15 +674,17 @@ class ReflectionEngine: async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]: """ - 解决冲突 - - 使用 LLM 对检测到的冲突进行反思和解决。 + Resolve detected conflicts + + Uses LLM to perform reflection and resolution on detected conflicts. + Processes conflicts in parallel for efficiency while respecting concurrency limits. 
Args: - conflicts: 冲突列表 + conflicts: List of conflicts to be resolved + statement_databasets: Statement database records for context Returns: - List[Any]: 解决方案列表 + List[Any]: List of resolution solutions with reflection results """ if not conflicts: return [] @@ -570,12 +693,12 @@ class ReflectionEngine: baseline = self.config.baseline memory_verify = self.config.memory_verify - # 并行处理每个冲突 + # Process each conflict in parallel async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]: - """解决单个冲突""" + """Resolve a single conflict""" async with self._semaphore: try: - # 渲染反思提示词 + # Render reflection prompt rendered_prompt = await self.render_reflexion_prompt_func( [conflict], self.reflexion_schema, @@ -587,7 +710,7 @@ class ReflectionEngine: messages = [{"role": "user", "content": rendered_prompt}] - # 调用 LLM 进行反思 + # Call LLM for reflection response = await self.llm_client.response_structured( messages, self.reflexion_schema @@ -596,7 +719,7 @@ class ReflectionEngine: if not response: return None - # 标准化返回格式 + # Standardize return format if isinstance(response, BaseModel): return response.model_dump() elif hasattr(response, 'dict'): @@ -610,11 +733,11 @@ class ReflectionEngine: logging.warning(f"解决单个冲突失败: {e}") return None - # 并发执行所有冲突解决任务 + # Execute all conflict resolution tasks concurrently tasks = [_resolve_one(conflict) for conflict in conflicts] results = await asyncio.gather(*tasks, return_exceptions=False) - # 过滤掉失败的结果 + # Filter out failed results solved = [r for r in results if r is not None] logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突") @@ -626,15 +749,16 @@ class ReflectionEngine: solved_data: List[Dict[str, Any]] ) -> int: """ - 应用反思结果(更新记忆库) - - 将解决冲突后的记忆更新到 Neo4j 数据库中。 + Apply reflection results (update memory database) + + Updates the Neo4j database with resolved conflicts and reflection results. + Processes the solved data and applies changes to the memory storage system. 
Args: - solved_data: 解决方案列表 + solved_data: List of resolved conflict solutions with reflection data Returns: - int: 成功更新的记忆数量 + int: Number of successfully updated memory entries """ changes = extract_and_process_changes(solved_data) success_count = await neo4j_data(changes) @@ -642,80 +766,86 @@ class ReflectionEngine: - # 基于时间的反思方法 + # Time-based reflection methods async def time_based_reflection( self, host_id: uuid.UUID, time_period: Optional[str] = None ) -> ReflectionResult: """ - 基于时间的反思 - - 根据时间周期触发反思,检查在指定时间段内的记忆。 + Time-based reflection + + Triggers reflection based on time cycles, checking memories within + specified time periods. Uses the configured iteration period if + no specific time period is provided. Args: - host_id: 主机ID - time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值 + host_id: Host UUID identifier for scoping reflection + time_period: Time period (e.g., "three hours"), uses config value if not provided Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result """ period = time_period or self.config.iteration_period logging.info(f"执行基于时间的反思,周期: {period}") - # 使用标准反思流程 + # Use standard reflection workflow return await self.execute_reflection(host_id) - # 基于事实的反思方法 + # Fact-based reflection methods async def fact_based_reflection( self, host_id: uuid.UUID ) -> ReflectionResult: """ - 基于事实的反思 - - 检测记忆中的事实冲突并解决。 + Fact-based reflection + + Detects and resolves factual conflicts within memories. Analyzes + memory data for inconsistencies and contradictions that need resolution. 
Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping reflection Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result """ logging.info("执行基于事实的反思") - # 使用标准反思流程 + # Use standard reflection workflow return await self.execute_reflection(host_id) - # 综合反思方法 + # Comprehensive reflection methods async def comprehensive_reflection( self, host_id: uuid.UUID ) -> ReflectionResult: """ - 综合反思 - - 整合基于时间和基于事实的反思策略。 + Comprehensive reflection + + Integrates time-based and fact-based reflection strategies based on + the configured baseline. Supports hybrid approaches that combine + multiple reflection methodologies. Args: - host_id: 主机ID + host_id: Host UUID identifier for scoping reflection Returns: - ReflectionResult: 反思结果 + ReflectionResult: Comprehensive reflection operation result combining + multiple strategies if using hybrid baseline """ logging.info("执行综合反思") - # 根据配置的基线选择反思策略 + # Choose reflection strategy based on configured baseline if self.config.baseline == ReflectionBaseline.TIME: return await self.time_based_reflection(host_id) elif self.config.baseline == ReflectionBaseline.FACT: return await self.fact_based_reflection(host_id) elif self.config.baseline == ReflectionBaseline.HYBRID: - # 混合策略:先执行基于时间的反思,再执行基于事实的反思 + # Hybrid strategy: execute time-based reflection first, then fact-based reflection time_result = await self.time_based_reflection(host_id) fact_result = await self.fact_based_reflection(host_id) - # 合并结果 + # Merge results return ReflectionResult( success=time_result.success and fact_result.success, message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}", diff --git a/api/app/core/memory/utils/data/text_utils.py b/api/app/core/memory/utils/data/text_utils.py index 133990f7..d0b10f97 100644 --- a/api/app/core/memory/utils/data/text_utils.py +++ b/api/app/core/memory/utils/data/text_utils.py @@ -2,9 +2,17 @@ import json def escape_lucene_query(query: str) -> str: - """Escape Lucene 
special characters in a free-text query. - - This prevents ParseException when using Neo4j full-text procedures. + """ + Escape special characters in Lucene queries + + Prevents ParseException when using Neo4j full-text search procedures. + Escapes all Lucene reserved special characters and operators. + + Args: + query: Original query string + + Returns: + str: Escaped query string safe for Lucene search """ if query is None: return "" @@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str: return s def extract_plain_query(query_input: str) -> str: - """Extract clean, plain-text query from various input forms. - + """ + Extract clean plain-text query from various input forms + + Handles the following cases: - Strips surrounding quotes and whitespace - If input looks like JSON, prefers the 'original' field - - Fallbacks to the raw string when parsing fails + - Falls back to raw string when parsing fails + - Handles dictionary-type input + - Best-effort unescape common escape characters + + Args: + query_input: Query input in various forms (string, dict, etc.) + + Returns: + str: Extracted plain-text query string """ if query_input is None: return "" diff --git a/api/app/core/memory/utils/data/time_utils.py b/api/app/core/memory/utils/data/time_utils.py index c6791dfc..763c642c 100644 --- a/api/app/core/memory/utils/data/time_utils.py +++ b/api/app/core/memory/utils/data/time_utils.py @@ -4,7 +4,13 @@ from datetime import datetime def validate_date_format(date_str: str) -> bool: """ - Validate if the date string is in the format YYYY-MM-DD. 
+ Validate if date string conforms to YYYY-MM-DD format + + Args: + date_str: Date string to validate + + Returns: + bool: True if format is correct, False otherwise """ pattern = r"^\d{4}-\d{1,2}-\d{1,2}$" return bool(re.match(pattern, date_str)) @@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str: def preprocess_date_string(date_str: str) -> str: - """预处理日期字符串,处理特殊格式""" + """ + 预处理日期字符串,处理特殊格式 + + 处理以下特殊格式: + - 年份后直接跟月份没有分隔符的格式(如 "20259/28") + - 无分隔符的纯数字格式(如 "20251028", "251028") + - 混合分隔符,统一为 "-" + + Args: + date_str: 原始日期字符串 + + Returns: + str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM" + """ # 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔) match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str) @@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str: def fallback_parse(date_str: str) -> str: - """备选解析方案""" + """ + 备选日期解析方案 + + 当智能解析失败时,尝试使用预定义的日期格式进行解析。 + 支持多种常见的日期格式,包括: + - YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD + - YYYYMMDD, YYMMDD + - MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY + - DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY + - YYYY-MM, YYYY/MM, YYYY.MM + + Args: + date_str: 待解析的日期字符串 + + Returns: + str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串 + """ # 尝试常见的日期格式[citation:4][citation:5] formats_to_try = [ diff --git a/api/app/core/memory/utils/prompt/template_render.py b/api/app/core/memory/utils/prompt/template_render.py index 68e0ffe4..4df8d55b 100644 --- a/api/app/core/memory/utils/prompt/template_render.py +++ b/api/app/core/memory/utils/prompt/template_render.py @@ -2,15 +2,15 @@ import os from jinja2 import Environment, FileSystemLoader from typing import List, Dict, Any - # Setup Jinja2 environment prompt_dir = os.path.join(os.path.dirname(__file__), "prompts") prompt_env = Environment(loader=FileSystemLoader(prompt_dir)) + async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, baseline: str = "TIME", - memory_verify: bool = False,quality_assessment:bool = False, - statement_databasets: List[str] = [],language_type:str = "zh") -> 
str: + memory_verify: bool = False, quality_assessment: bool = False, + statement_databasets=None, language_type: str = "zh") -> str: """ Renders the evaluate prompt using the evaluate_optimized.jinja2 template. @@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, Returns: Rendered prompt content as string """ + if statement_databasets is None: + statement_databasets = [] template = prompt_env.get_template("evaluate.jinja2") # Convert Pydantic model to JSON schema if needed @@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any, async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False, - statement_databasets: List[str] = [],language_type:str = "zh") -> str: + statement_databasets=None, language_type: str = "zh") -> str: """ Renders the reflexion prompt using the reflexion_optimized.jinja2 template. @@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s Returns: Rendered prompt content as a string. 
""" + if statement_databasets is None: + statement_databasets = [] template = prompt_env.get_template("reflexion.jinja2") # Convert Pydantic model to JSON schema if needed @@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s json_schema = schema rendered_prompt = template.render(data=data, json_schema=json_schema, - baseline=baseline,memory_verify=memory_verify, - statement_databasets=statement_databasets,language_type=language_type) + baseline=baseline, memory_verify=memory_verify, + statement_databasets=statement_databasets, language_type=language_type) return rendered_prompt diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index dba6717d..4a453c6b 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -1,23 +1,19 @@ from __future__ import annotations -import asyncio import os -import time -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Dict, Optional, TypeVar + +from langchain_aws import ChatBedrock +from langchain_community.chat_models import ChatTongyi +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLLM +from langchain_ollama import OllamaLLM +from langchain_openai import ChatOpenAI, OpenAI +from pydantic import BaseModel, Field -import httpx from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.models.models_model import ModelProvider, ModelType -from langchain_community.document_compressors import JinaRerank -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.embeddings import Embeddings -from langchain_core.language_models import BaseLanguageModel, BaseLLM -from langchain_core.outputs import Generation, LLMResult -from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables import RunnableSerializable -from pydantic import BaseModel, Field T = 
TypeVar("T") @@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: - from langchain_openai import ChatOpenAI return ChatOpenAI - - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if type == ModelType.LLM: - from langchain_openai import OpenAI return OpenAI elif type == ModelType.CHAT: - from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: - from langchain_community.chat_models import ChatTongyi return ChatTongyi elif provider == ModelProvider.OLLAMA: - from langchain_ollama import OllamaLLM return OllamaLLM elif provider == ModelProvider.BEDROCK: - from langchain_aws import ChatBedrock, ChatBedrockConverse - return ChatBedrock else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 39c7887b..0e3fecee 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject from app.db import get_db_read from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.schemas import FileInput +from app.schemas.model_schema import ModelInfo from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -620,11 +621,12 @@ class BaseNode(ABC): @staticmethod async def process_message( - provider: str, - is_omni: bool, + api_config: ModelInfo, content: str | dict | FileObject, + end_user_id: str, enable_file=False ) -> list | str | None: + provider = api_config.provider if isinstance(content, dict): content = FileObject( 
type=content.get("type"), @@ -643,7 +645,7 @@ class BaseNode(ABC): if content.content_cache.get(provider): return content.content_cache[provider] with get_db_read() as db: - multimodel_service = MultimodalService(db, provider, is_omni=is_omni) + multimodel_service = MultimodalService(db, api_config=api_config) file_obj = FileInput( type=content.type, url=content.url, @@ -653,7 +655,8 @@ class BaseNode(ABC): ) file_obj.set_content(content.get_content()) message = await multimodel_service.process_files( - [file_obj] + end_user_id, + [file_obj], ) content.set_content(file_obj.get_content()) if message: diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 29f7085b..7e98efab 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -5,7 +5,7 @@ from typing import Any from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.variable.base_variable import VariableType @@ -23,6 +23,26 @@ class IfElseNode(BaseNode): "output": VariableType.STRING } + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + result = [] + for case in self.typed_config.cases: + expressions = [] + for expression in case.expressions: + expressions.append({ + "left": self.get_variable(expression.left, variable_pool, strict=False), + "right": expression.right + if expression.input_type == ValueInputType.CONSTANT + else self.get_variable(expression.right, variable_pool, strict=False), 
+ "operator": expression.operator, + }) + result.append({ + "expressions": expressions, + "logical_operator": case.logical_operator, + }) + return { + "cases": result + } + @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: match operator: diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 696298eb..14f789a9 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode): "output": VariableType.ARRAY_STRING } + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return { + "query": self._render_template(self.typed_config.query, variable_pool), + "knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases], + } + @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): """ diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 186c204f..b293d1f4 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_context from app.models import ModelType +from app.schemas.model_schema import ModelInfo from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) @@ -113,12 +114,15 @@ class LLMNode(BaseNode): # 在 Session 关闭前提取所有需要的数据 api_config = self.model_balance(config) - model_name = api_config.model_name - provider = api_config.provider - api_key = api_config.api_key - api_base = api_config.api_base - is_omni = api_config.is_omni - model_type = config.type + model_info = ModelInfo( + model_name=api_config.model_name, + model_type=ModelType(config.type), + 
api_key=api_config.api_key, + api_base=api_config.api_base, + provider=api_config.provider, + is_omni=api_config.is_omni, + capability=api_config.capability + ) # 4. 创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True @@ -126,17 +130,18 @@ class LLMNode(BaseNode): llm = RedBearLLM( RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, + model_name=model_info.model_name, + provider=model_info.provider, + api_key=model_info.api_key, + base_url=model_info.api_base, extra_params=extra_params, - is_omni=is_omni + is_omni=model_info.is_omni ), - type=ModelType(model_type) + type=model_info.model_type ) - logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") + logger.debug( + f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") messages_config = self.typed_config.messages @@ -148,35 +153,40 @@ class LLMNode(BaseNode): content_template = msg_config.content content_template = self._render_context(content_template, variable_pool) content = self._render_template(content_template, variable_pool) - + user_id = self.get_variable("sys.user_id", variable_pool) # 根据角色创建对应的消息对象 if role == "system": messages.append({ "role": "system", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message( + model_info, + content, + user_id, + self.typed_config.vision, + ) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) 
}) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) + "content": await self.process_message(model_info, content, user_id, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) + content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision) if content: file_content.extend(content) if messages and messages[-1]["role"] == 'user': @@ -190,14 +200,19 @@ class LLMNode(BaseNode): if isinstance(message["content"], list): file_content = [] for file in message["content"]: - content = await self.process_message(provider, is_omni, file, self.typed_config.vision) + content = await self.process_message(model_info, file, user_id, self.typed_config.vision) if content: file_content.extend(content) history_message.append( {"role": message["role"], "content": file_content} ) else: - message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) + message["content"] = await self.process_message( + model_info, + message["content"], + user_id, + self.typed_config.vision + ) history_message.append(message) messages = messages[:-1] + history_message + messages[-1:] self.messages = messages @@ -293,7 +308,7 @@ class LLMNode(BaseNode): # 调用 LLM(流式,支持字符串或消息列表) last_meta_data = {} - async for chunk in llm.astream(self.messages, stream_usage=True): + async for chunk in llm.astream(self.messages): # 提取内容 if hasattr(chunk, 'content'): content = self.process_model_output(chunk.content) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 700ed85f..acac09e4 100644 --- 
a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode): } return None + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + return { + "text": self._render_template(self.typed_config.text, variable_pool), + "prompt": self._render_template(self.typed_config.prompt, variable_pool), + "params": [param.model_dump(mode="json") for param in self.typed_config.params], + "model_id": str(self.typed_config.model_id), + } + def _output_types(self) -> dict[str, VariableType]: outputs = {} for param in self.typed_config.params: diff --git a/api/app/i18n/README.md b/api/app/i18n/README.md new file mode 100644 index 00000000..7374e966 --- /dev/null +++ b/api/app/i18n/README.md @@ -0,0 +1,61 @@ +# Internationalization (i18n) Module + +This module provides internationalization support for the MemoryBear API. + +## Components + +- `service.py` - Translation service and core translation logic +- `middleware.py` - Language detection middleware +- `dependencies.py` - FastAPI dependency injection functions +- `exceptions.py` - Internationalized exception classes + +## Usage + +### Basic Translation + +```python +from app.i18n import t + +# Simple translation +message = t("common.success.created") + +# Parameterized translation +message = t("common.validation.required", field="Name") +``` + +### Enum Translation + +```python +from app.i18n import t_enum + +# Translate enum value +role_display = t_enum("workspace_role", "manager") +``` + +### In FastAPI Endpoints + +```python +from fastapi import Depends +from app.i18n.dependencies import get_translator + +@router.post("/workspaces") +async def create_workspace( + data: WorkspaceCreate, + t: Callable = Depends(get_translator) +): + workspace = await workspace_service.create(data) + return { + "success": True, + "message": t("workspace.created_successfully"), + "data": 
workspace + } +``` + +## Configuration + +See `app/core/config.py` for i18n configuration options: + +- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh") +- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en") +- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true) +- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true) diff --git a/api/app/i18n/__init__.py b/api/app/i18n/__init__.py new file mode 100644 index 00000000..38d2b5bd --- /dev/null +++ b/api/app/i18n/__init__.py @@ -0,0 +1,124 @@ +""" +Internationalization (i18n) module for MemoryBear Enterprise. + +This module provides complete i18n support for the backend API including: +- Translation loading from multiple directories (community + enterprise) +- Translation service with caching and fallback +- Language detection middleware +- Dependency injection for FastAPI +- Convenience functions for easy usage + +Usage: + from app.i18n import t, t_enum + + # Simple translation + message = t("common.success.created") + + # Parameterized translation + error = t("common.validation.required", field="名称") + + # Enum translation + role_display = t_enum("workspace_role", "manager") +""" + +from app.i18n.dependencies import ( + get_current_language, + get_enum_translator, + get_translator, +) +from app.i18n.exceptions import ( + BadRequestError, + ConflictError, + FileNotFoundError, + FileTooLargeError, + ForbiddenError, + I18nException, + InternalServerError, + InvalidCredentialsError, + InvalidFileTypeError, + NotFoundError, + QuotaExceededError, + RateLimitExceededError, + ServiceUnavailableError, + TenantNotFoundError, + TenantSuspendedError, + TokenExpiredError, + TokenInvalidError, + UnauthorizedError, + UserAlreadyExistsError, + UserNotFoundError, + ValidationError, + WorkspaceNotFoundError, + WorkspacePermissionDeniedError, + get_current_locale, + set_current_locale, +) +from app.i18n.loader import TranslationLoader +from app.i18n.logger import ( + 
TranslationLogger, + get_translation_logger, + log_missing_translation, + log_translation_error, +) +from app.i18n.middleware import LanguageMiddleware +from app.i18n.serializers import ( + I18nResponseMixin, + WorkspaceSerializer, + WorkspaceMemberSerializer, + WorkspaceInviteSerializer, +) +from app.i18n.service import ( + TranslationService, + get_translation_service, + t, + t_enum, +) + +__all__ = [ + "TranslationLoader", + "LanguageMiddleware", + "TranslationService", + "get_translation_service", + "t", + "t_enum", + "get_current_language", + "get_translator", + "get_enum_translator", + # Context management + "get_current_locale", + "set_current_locale", + # Logging + "TranslationLogger", + "get_translation_logger", + "log_missing_translation", + "log_translation_error", + # Serializers + "I18nResponseMixin", + "WorkspaceSerializer", + "WorkspaceMemberSerializer", + "WorkspaceInviteSerializer", + # Exception classes + "I18nException", + "BadRequestError", + "UnauthorizedError", + "ForbiddenError", + "NotFoundError", + "ConflictError", + "ValidationError", + "InternalServerError", + "ServiceUnavailableError", + "WorkspaceNotFoundError", + "WorkspacePermissionDeniedError", + "UserNotFoundError", + "UserAlreadyExistsError", + "TenantNotFoundError", + "TenantSuspendedError", + "InvalidCredentialsError", + "TokenExpiredError", + "TokenInvalidError", + "FileNotFoundError", + "FileTooLargeError", + "InvalidFileTypeError", + "RateLimitExceededError", + "QuotaExceededError", +] diff --git a/api/app/i18n/cache.py b/api/app/i18n/cache.py new file mode 100644 index 00000000..5b0837d9 --- /dev/null +++ b/api/app/i18n/cache.py @@ -0,0 +1,291 @@ +""" +Advanced caching system for i18n translations. 
+ +This module provides: +- LRU cache for hot translations +- Lazy loading mechanism +- Memory optimization +- Cache statistics +""" + +import logging +from functools import lru_cache +from typing import Any, Dict, Optional +from collections import OrderedDict +import time + +logger = logging.getLogger(__name__) + + +class TranslationCache: + """ + Advanced translation cache with LRU eviction and lazy loading. + + Features: + - LRU cache for frequently accessed translations + - Lazy loading to reduce startup time + - Memory-efficient storage + - Cache hit/miss statistics + """ + + def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True): + """ + Initialize the translation cache. + + Args: + max_lru_size: Maximum size of LRU cache for hot translations + enable_lazy_load: Enable lazy loading of locales + """ + self.max_lru_size = max_lru_size + self.enable_lazy_load = enable_lazy_load + + # Main cache: {locale: {namespace: {key: value}}} + self._main_cache: Dict[str, Dict[str, Any]] = {} + + # LRU cache for hot translations + self._lru_cache: OrderedDict = OrderedDict() + + # Loaded locales tracker + self._loaded_locales: set = set() + + # Statistics + self._stats = { + "hits": 0, + "misses": 0, + "lru_hits": 0, + "lru_misses": 0, + "lazy_loads": 0 + } + + logger.info( + f"TranslationCache initialized with LRU size: {max_lru_size}, " + f"lazy loading: {enable_lazy_load}" + ) + + def set_locale_data(self, locale: str, data: Dict[str, Any]): + """ + Set translation data for a locale. + + Args: + locale: Locale code + data: Translation data dictionary + """ + self._main_cache[locale] = data + self._loaded_locales.add(locale) + logger.debug(f"Loaded locale '{locale}' into cache") + + def get_translation( + self, + locale: str, + namespace: str, + key_path: list + ) -> Optional[str]: + """ + Get translation from cache with LRU optimization. 
+ + Args: + locale: Locale code + namespace: Translation namespace + key_path: List of nested keys + + Returns: + Translation string or None if not found + """ + # Build cache key for LRU + cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}" + + # Check LRU cache first (hot translations) + if cache_key in self._lru_cache: + self._stats["lru_hits"] += 1 + self._stats["hits"] += 1 + # Move to end (most recently used) + self._lru_cache.move_to_end(cache_key) + return self._lru_cache[cache_key] + + self._stats["lru_misses"] += 1 + + # Check main cache + if locale not in self._main_cache: + self._stats["misses"] += 1 + return None + + if namespace not in self._main_cache[locale]: + self._stats["misses"] += 1 + return None + + # Navigate through nested keys + current = self._main_cache[locale][namespace] + for key in key_path: + if isinstance(current, dict) and key in current: + current = current[key] + else: + self._stats["misses"] += 1 + return None + + # Return only if it's a string value + if not isinstance(current, str): + self._stats["misses"] += 1 + return None + + self._stats["hits"] += 1 + + # Add to LRU cache + self._add_to_lru(cache_key, current) + + return current + + def _add_to_lru(self, key: str, value: str): + """ + Add translation to LRU cache. + + Args: + key: Cache key + value: Translation value + """ + # Remove oldest if cache is full + if len(self._lru_cache) >= self.max_lru_size: + self._lru_cache.popitem(last=False) + + self._lru_cache[key] = value + + def is_locale_loaded(self, locale: str) -> bool: + """ + Check if a locale is loaded. + + Args: + locale: Locale code + + Returns: + True if locale is loaded + """ + return locale in self._loaded_locales + + def get_loaded_locales(self) -> list: + """ + Get list of loaded locales. 
+ + Returns: + List of locale codes + """ + return list(self._loaded_locales) + + def clear_lru(self): + """Clear the LRU cache.""" + self._lru_cache.clear() + logger.info("LRU cache cleared") + + def clear_locale(self, locale: str): + """ + Clear cache for a specific locale. + + Args: + locale: Locale code + """ + if locale in self._main_cache: + del self._main_cache[locale] + self._loaded_locales.discard(locale) + + # Clear related LRU entries + keys_to_remove = [k for k in self._lru_cache if k.startswith(f"{locale}:")] + for key in keys_to_remove: + del self._lru_cache[key] + + logger.info(f"Cleared cache for locale '{locale}'") + + def clear_all(self): + """Clear all caches.""" + self._main_cache.clear() + self._lru_cache.clear() + self._loaded_locales.clear() + logger.info("All caches cleared") + + def get_stats(self) -> Dict[str, Any]: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics + """ + total_requests = self._stats["hits"] + self._stats["misses"] + hit_rate = ( + self._stats["hits"] / total_requests * 100 + if total_requests > 0 + else 0 + ) + + lru_total = self._stats["lru_hits"] + self._stats["lru_misses"] + lru_hit_rate = ( + self._stats["lru_hits"] / lru_total * 100 + if lru_total > 0 + else 0 + ) + + return { + "total_requests": total_requests, + "hits": self._stats["hits"], + "misses": self._stats["misses"], + "hit_rate": round(hit_rate, 2), + "lru_hits": self._stats["lru_hits"], + "lru_misses": self._stats["lru_misses"], + "lru_hit_rate": round(lru_hit_rate, 2), + "lru_size": len(self._lru_cache), + "lru_max_size": self.max_lru_size, + "loaded_locales": len(self._loaded_locales), + "lazy_loads": self._stats["lazy_loads"] + } + + def reset_stats(self): + """Reset cache statistics.""" + self._stats = { + "hits": 0, + "misses": 0, + "lru_hits": 0, + "lru_misses": 0, + "lazy_loads": 0 + } + logger.info("Cache statistics reset") + + def get_memory_usage(self) -> Dict[str, Any]: + """ + Estimate memory usage of the cache. 
+ + Returns: + Dictionary with memory usage information + """ + import sys + + main_cache_size = sys.getsizeof(self._main_cache) + lru_cache_size = sys.getsizeof(self._lru_cache) + + # Rough estimate of nested data + for locale_data in self._main_cache.values(): + main_cache_size += sys.getsizeof(locale_data) + for namespace_data in locale_data.values(): + main_cache_size += sys.getsizeof(namespace_data) + + return { + "main_cache_bytes": main_cache_size, + "lru_cache_bytes": lru_cache_size, + "total_bytes": main_cache_size + lru_cache_size, + "main_cache_mb": round(main_cache_size / 1024 / 1024, 2), + "lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2), + "total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2) + } + + +@lru_cache(maxsize=128) +def get_cached_translation_key(locale: str, namespace: str, key: str) -> str: + """ + LRU cached function for building translation cache keys. + + This reduces string concatenation overhead for frequently accessed keys. + + Args: + locale: Locale code + namespace: Translation namespace + key: Translation key + + Returns: + Cache key string + """ + return f"{locale}:{namespace}:{key}" diff --git a/api/app/i18n/dependencies.py b/api/app/i18n/dependencies.py new file mode 100644 index 00000000..4c8e9a11 --- /dev/null +++ b/api/app/i18n/dependencies.py @@ -0,0 +1,158 @@ +""" +FastAPI dependency injection functions for i18n. + +This module provides dependency injection functions that can be used +in FastAPI route handlers to access the current language and translator. +""" + +import logging +from typing import Callable + +from fastapi import Request + +from app.i18n.service import get_translation_service + +logger = logging.getLogger(__name__) + + +async def get_current_language(request: Request) -> str: + """ + Get the current language from the request context. + + This dependency extracts the language that was determined by the + LanguageMiddleware and stored in request.state. 
+ + Args: + request: FastAPI request object + + Returns: + Language code (e.g., "zh", "en") + + Usage: + @router.get("/example") + async def example(language: str = Depends(get_current_language)): + return {"language": language} + """ + # Get language from request state (set by LanguageMiddleware) + language = getattr(request.state, "language", None) + + if language is None: + # Fallback to default language if not set + from app.core.config import settings + language = settings.I18N_DEFAULT_LANGUAGE + logger.warning( + "Language not found in request.state, using default: " + f"{language}" + ) + + return language + + +async def get_translator(request: Request) -> Callable: + """ + Get a translator function bound to the current request's language. + + This dependency returns a translation function that automatically + uses the current request's language, making it easy to translate + strings in route handlers. + + Args: + request: FastAPI request object + + Returns: + Translation function with signature: t(key: str, **params) -> str + + Usage: + @router.post("/workspaces") + async def create_workspace( + data: WorkspaceCreate, + t: Callable = Depends(get_translator) + ): + workspace = await workspace_service.create(data) + return { + "success": True, + "message": t("workspace.created_successfully"), + "data": workspace + } + + # With parameters + @router.get("/items") + async def get_items(t: Callable = Depends(get_translator)): + count = 5 + return { + "message": t("items.found", count=count) + } + """ + # Get current language + language = await get_current_language(request) + + # Get translation service + service = get_translation_service() + + # Return a bound translation function + def translate(key: str, **params) -> str: + """ + Translate a key using the current request's language. 
+ + Args: + key: Translation key (e.g., "common.success.created") + **params: Parameters for parameterized messages + + Returns: + Translated string + """ + return service.translate(key, language, **params) + + return translate + + +async def get_enum_translator(request: Request) -> Callable: + """ + Get an enum translator function bound to the current request's language. + + This dependency returns a function for translating enum values + that automatically uses the current request's language. + + Args: + request: FastAPI request object + + Returns: + Enum translation function with signature: + t_enum(enum_type: str, value: str) -> str + + Usage: + @router.get("/workspace/{id}") + async def get_workspace( + id: str, + t_enum: Callable = Depends(get_enum_translator) + ): + workspace = await workspace_service.get(id) + return { + "id": workspace.id, + "role": workspace.role, + "role_display": t_enum("workspace_role", workspace.role), + "status": workspace.status, + "status_display": t_enum("workspace_status", workspace.status) + } + """ + # Get current language + language = await get_current_language(request) + + # Get translation service + service = get_translation_service() + + # Return a bound enum translation function + def translate_enum(enum_type: str, value: str) -> str: + """ + Translate an enum value using the current request's language. + + Args: + enum_type: Enum type name (e.g., "workspace_role") + value: Enum value (e.g., "manager") + + Returns: + Translated enum display name + """ + return service.translate_enum(enum_type, value, language) + + return translate_enum diff --git a/api/app/i18n/exceptions.py b/api/app/i18n/exceptions.py new file mode 100644 index 00000000..b81369ed --- /dev/null +++ b/api/app/i18n/exceptions.py @@ -0,0 +1,495 @@ +""" +Internationalized exception classes for i18n system. + +This module provides exception classes that automatically translate +error messages based on the current request's language. 
+""" + +import logging +from contextvars import ContextVar +from typing import Any, Dict, Optional + +from fastapi import HTTPException, Request + +from app.i18n.service import get_translation_service + +logger = logging.getLogger(__name__) + +# Context variable to store current locale +_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None) + + +def set_current_locale(locale: str) -> None: + """ + Set the current locale in the context variable. + + This should be called by the LanguageMiddleware. + + Args: + locale: Locale code (e.g., "zh", "en") + """ + _current_locale.set(locale) + + +def get_current_locale() -> Optional[str]: + """ + Get the current locale from the context variable. + + Returns: + Locale code or None if not set + """ + return _current_locale.get() + + +class I18nException(HTTPException): + """ + Base exception class with automatic i18n support. + + This exception automatically translates error messages based on: + 1. The current request's language (from request.state.language) + 2. The fallback language if request language is not available + 3. The error key itself if no translation is found + + Features: + - Automatic error message translation + - Parameterized error messages support + - Consistent error response format + - Language-aware error handling + + Usage: + # Simple error + raise I18nException( + error_key="errors.workspace.not_found", + status_code=404 + ) + + # Error with parameters + raise I18nException( + error_key="errors.validation.missing_field", + status_code=400, + field="name" + ) + + # Custom error code + raise I18nException( + error_key="errors.workspace.not_found", + error_code="WORKSPACE_NOT_FOUND", + status_code=404, + workspace_id="123" + ) + """ + + def __init__( + self, + error_key: str, + status_code: int = 400, + error_code: Optional[str] = None, + locale: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + **params + ): + """ + Initialize the i18n exception. 
+ + Args: + error_key: Translation key for the error message + (e.g., "errors.workspace.not_found") + status_code: HTTP status code (default: 400) + error_code: Custom error code for API clients + (default: derived from error_key) + locale: Target locale for translation (optional) + If not provided, uses current request's language + headers: Additional HTTP headers + **params: Parameters for parameterized error messages + """ + self.error_key = error_key + self.error_code = error_code or self._generate_error_code(error_key) + self.params = params + + # Get locale from request context if not provided + if locale is None: + locale = self._get_current_locale() + + # Translate error message + translation_service = get_translation_service() + message = translation_service.translate( + error_key, + locale, + **params + ) + + # Build error detail + detail = { + "error_code": self.error_code, + "message": message, + } + + # Add parameters to detail if provided + if params: + detail["params"] = params + + # Initialize HTTPException + super().__init__( + status_code=status_code, + detail=detail, + headers=headers + ) + + logger.debug( + f"I18nException raised: {self.error_code} " + f"(key: {error_key}, locale: {locale})" + ) + + def _get_current_locale(self) -> str: + """ + Get the current locale from request context. + + Returns: + Locale code (e.g., "zh", "en") + """ + try: + # Try to get locale from context variable + locale = _current_locale.get() + if locale: + return locale + except Exception as e: + logger.debug(f"Could not get locale from context: {e}") + + # Fallback to default locale + from app.core.config import settings + return settings.I18N_DEFAULT_LANGUAGE + + def _generate_error_code(self, error_key: str) -> str: + """ + Generate error code from error key. + + Converts "errors.workspace.not_found" to "WORKSPACE_NOT_FOUND" + + Args: + error_key: Translation key + + Returns: + Error code in UPPER_SNAKE_CASE + """ + # Remove "errors." 
prefix if present + if error_key.startswith("errors."): + error_key = error_key[7:] + + # Convert to UPPER_SNAKE_CASE + parts = error_key.split(".") + return "_".join(parts).upper() + + +# Specific exception classes for common errors + +class BadRequestError(I18nException): + """Bad request error (400).""" + + def __init__( + self, + error_key: str = "errors.common.bad_request", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=400, + error_code=error_code, + **params + ) + + +class UnauthorizedError(I18nException): + """Unauthorized error (401).""" + + def __init__( + self, + error_key: str = "errors.auth.unauthorized", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=401, + error_code=error_code, + **params + ) + + +class ForbiddenError(I18nException): + """Forbidden error (403).""" + + def __init__( + self, + error_key: str = "errors.auth.forbidden", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=403, + error_code=error_code, + **params + ) + + +class NotFoundError(I18nException): + """Not found error (404).""" + + def __init__( + self, + error_key: str = "errors.common.not_found", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=404, + error_code=error_code, + **params + ) + + +class ConflictError(I18nException): + """Conflict error (409).""" + + def __init__( + self, + error_key: str = "errors.common.conflict", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=409, + error_code=error_code, + **params + ) + + +class ValidationError(I18nException): + """Validation error (422).""" + + def __init__( + self, + error_key: str = "errors.common.validation_failed", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + 
status_code=422, + error_code=error_code, + **params + ) + + +class InternalServerError(I18nException): + """Internal server error (500).""" + + def __init__( + self, + error_key: str = "errors.common.internal_error", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=500, + error_code=error_code, + **params + ) + + +class ServiceUnavailableError(I18nException): + """Service unavailable error (503).""" + + def __init__( + self, + error_key: str = "errors.common.service_unavailable", + error_code: Optional[str] = None, + **params + ): + super().__init__( + error_key=error_key, + status_code=503, + error_code=error_code, + **params + ) + + +# Domain-specific exception classes + +class WorkspaceNotFoundError(NotFoundError): + """Workspace not found error.""" + + def __init__(self, workspace_id: Optional[str] = None, **params): + if workspace_id: + params["workspace_id"] = workspace_id + super().__init__( + error_key="errors.workspace.not_found", + error_code="WORKSPACE_NOT_FOUND", + **params + ) + + +class WorkspacePermissionDeniedError(ForbiddenError): + """Workspace permission denied error.""" + + def __init__(self, workspace_id: Optional[str] = None, **params): + if workspace_id: + params["workspace_id"] = workspace_id + super().__init__( + error_key="errors.workspace.permission_denied", + error_code="WORKSPACE_PERMISSION_DENIED", + **params + ) + + +class UserNotFoundError(NotFoundError): + """User not found error.""" + + def __init__(self, user_id: Optional[str] = None, **params): + if user_id: + params["user_id"] = user_id + super().__init__( + error_key="errors.user.not_found", + error_code="USER_NOT_FOUND", + **params + ) + + +class UserAlreadyExistsError(ConflictError): + """User already exists error.""" + + def __init__(self, identifier: Optional[str] = None, **params): + if identifier: + params["identifier"] = identifier + super().__init__( + error_key="errors.user.already_exists", + 
error_code="USER_ALREADY_EXISTS", + **params + ) + + +class TenantNotFoundError(NotFoundError): + """Tenant not found error.""" + + def __init__(self, tenant_id: Optional[str] = None, **params): + if tenant_id: + params["tenant_id"] = tenant_id + super().__init__( + error_key="errors.tenant.not_found", + error_code="TENANT_NOT_FOUND", + **params + ) + + +class TenantSuspendedError(ForbiddenError): + """Tenant suspended error.""" + + def __init__(self, tenant_id: Optional[str] = None, **params): + if tenant_id: + params["tenant_id"] = tenant_id + super().__init__( + error_key="errors.tenant.suspended", + error_code="TENANT_SUSPENDED", + **params + ) + + +class InvalidCredentialsError(UnauthorizedError): + """Invalid credentials error.""" + + def __init__(self, **params): + super().__init__( + error_key="errors.auth.invalid_credentials", + error_code="INVALID_CREDENTIALS", + **params + ) + + +class TokenExpiredError(UnauthorizedError): + """Token expired error.""" + + def __init__(self, **params): + super().__init__( + error_key="errors.auth.token_expired", + error_code="TOKEN_EXPIRED", + **params + ) + + +class TokenInvalidError(UnauthorizedError): + """Token invalid error.""" + + def __init__(self, **params): + super().__init__( + error_key="errors.auth.token_invalid", + error_code="TOKEN_INVALID", + **params + ) + + +class FileNotFoundError(NotFoundError): + """File not found error.""" + + def __init__(self, file_id: Optional[str] = None, **params): + if file_id: + params["file_id"] = file_id + super().__init__( + error_key="errors.file.not_found", + error_code="FILE_NOT_FOUND", + **params + ) + + +class FileTooLargeError(BadRequestError): + """File too large error.""" + + def __init__(self, max_size: Optional[str] = None, **params): + if max_size: + params["max_size"] = max_size + super().__init__( + error_key="errors.file.too_large", + error_code="FILE_TOO_LARGE", + **params + ) + + +class InvalidFileTypeError(BadRequestError): + """Invalid file type error.""" + 
+ def __init__(self, file_type: Optional[str] = None, **params): + if file_type: + params["file_type"] = file_type + super().__init__( + error_key="errors.file.invalid_type", + error_code="INVALID_FILE_TYPE", + **params + ) + + +class RateLimitExceededError(I18nException): + """Rate limit exceeded error (429).""" + + def __init__(self, **params): + super().__init__( + error_key="errors.api.rate_limit_exceeded", + status_code=429, + error_code="RATE_LIMIT_EXCEEDED", + **params + ) + + +class QuotaExceededError(ForbiddenError): + """Quota exceeded error.""" + + def __init__(self, resource: Optional[str] = None, **params): + if resource: + params["resource"] = resource + super().__init__( + error_key="errors.api.quota_exceeded", + error_code="QUOTA_EXCEEDED", + **params + ) diff --git a/api/app/i18n/loader.py b/api/app/i18n/loader.py new file mode 100644 index 00000000..3865378b --- /dev/null +++ b/api/app/i18n/loader.py @@ -0,0 +1,199 @@ +""" +Translation file loader for i18n system. + +This module handles loading translation files from multiple directories +(community edition + enterprise edition) and provides hot reload support. +""" + +import json +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class TranslationLoader: + """ + Translation file loader that supports: + - Loading from multiple directories (community + enterprise) + - Hot reload of translation files + - Automatic locale detection + """ + + def __init__(self, locales_dirs: Optional[List[str]] = None): + """ + Initialize the translation loader. + + Args: + locales_dirs: List of directories containing translation files. + If None, will auto-detect from settings. 
+ """ + if locales_dirs is None: + locales_dirs = self._detect_locales_dirs() + + self.locales_dirs = [Path(d) for d in locales_dirs] + logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}") + + def _detect_locales_dirs(self) -> List[str]: + """ + Auto-detect translation directories from settings. + + Returns: + List of translation directory paths + """ + from app.core.config import settings + + dirs = [] + + # 1. Core locales directory (community edition, required) + core_dir = Path(settings.I18N_CORE_LOCALES_DIR) + if core_dir.exists(): + dirs.append(str(core_dir)) + logger.debug(f"Found core locales directory: {core_dir}") + else: + logger.warning(f"Core locales directory not found: {core_dir}") + + # 2. Premium locales directory (enterprise edition, optional) + if settings.I18N_PREMIUM_LOCALES_DIR: + premium_dir = Path(settings.I18N_PREMIUM_LOCALES_DIR) + if premium_dir.exists(): + dirs.append(str(premium_dir)) + logger.debug(f"Found premium locales directory: {premium_dir}") + else: + # Auto-detect premium directory + premium_dir = Path("premium/locales") + if premium_dir.exists(): + dirs.append(str(premium_dir)) + logger.debug(f"Auto-detected premium locales directory: {premium_dir}") + + if not dirs: + logger.error("No translation directories found!") + + return dirs + + def get_available_locales(self) -> List[str]: + """ + Get list of all available locales across all directories. + + Returns: + List of locale codes (e.g., ['zh', 'en']) + """ + locales = set() + + for locales_dir in self.locales_dirs: + if not locales_dir.exists(): + continue + + for locale_dir in locales_dir.iterdir(): + if locale_dir.is_dir() and not locale_dir.name.startswith('.'): + locales.add(locale_dir.name) + + return sorted(list(locales)) + + def load_locale(self, locale: str) -> Dict[str, Any]: + """ + Load all translation files for a specific locale from all directories. 
+ + Translation files are merged with priority: + - Later directories override earlier directories + - Enterprise translations override community translations + + Args: + locale: Locale code (e.g., 'zh', 'en') + + Returns: + Dictionary of translations organized by namespace + Format: {namespace: {key: value, ...}, ...} + """ + translations = {} + + # Load from each directory in order (later directories override earlier) + for locales_dir in self.locales_dirs: + locale_dir = locales_dir / locale + if not locale_dir.exists(): + logger.debug(f"Locale directory not found: {locale_dir}") + continue + + # Load all JSON files in this locale directory + for json_file in locale_dir.glob("*.json"): + namespace = json_file.stem + + try: + with open(json_file, "r", encoding="utf-8") as f: + new_translations = json.load(f) + + # Merge translations (deep merge) + if namespace in translations: + translations[namespace] = self._deep_merge( + translations[namespace], + new_translations + ) + logger.debug( + f"Merged translations: {locale}/{namespace} from {json_file}" + ) + else: + translations[namespace] = new_translations + logger.debug( + f"Loaded translations: {locale}/{namespace} from {json_file}" + ) + + except json.JSONDecodeError as e: + logger.error( + f"Failed to parse JSON file {json_file}: {e}" + ) + except Exception as e: + logger.error( + f"Failed to load translation file {json_file}: {e}" + ) + + if not translations: + logger.warning(f"No translations found for locale: {locale}") + + return translations + + def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]: + """ + Reload translation files. + + Args: + locale: Specific locale to reload. If None, reloads all locales. 
+ + Returns: + Dictionary of reloaded translations + Format: {locale: {namespace: {key: value}}} + """ + if locale: + logger.info(f"Reloading translations for locale: {locale}") + return {locale: self.load_locale(locale)} + else: + logger.info("Reloading all translations") + all_translations = {} + for loc in self.get_available_locales(): + all_translations[loc] = self.load_locale(loc) + return all_translations + + def _deep_merge(self, base: Dict, override: Dict) -> Dict: + """ + Deep merge two dictionaries. + + Args: + base: Base dictionary + override: Dictionary with values to override + + Returns: + Merged dictionary + """ + result = base.copy() + + for key, value in override.items(): + if ( + key in result + and isinstance(result[key], dict) + and isinstance(value, dict) + ): + result[key] = self._deep_merge(result[key], value) + else: + result[key] = value + + return result diff --git a/api/app/i18n/logger.py b/api/app/i18n/logger.py new file mode 100644 index 00000000..9a81fc79 --- /dev/null +++ b/api/app/i18n/logger.py @@ -0,0 +1,382 @@ +""" +Translation logging for i18n system. + +This module provides: +- TranslationLogger for recording missing translations +- Missing translation report generation +- Integration with existing logging system +- Structured logging for translation events +""" + +import logging +from typing import Dict, List, Optional, Set +from datetime import datetime +from collections import defaultdict +from pathlib import Path +import json + +from app.core.logging_config import get_logger + +logger = get_logger(__name__) + + +class TranslationLogger: + """ + Logger for translation events and missing translations. + + Features: + - Records missing translations with context + - Generates missing translation reports + - Integrates with existing logging system + - Provides structured logging for analysis + """ + + def __init__(self, log_file: Optional[str] = None): + """ + Initialize translation logger. 
+ + Args: + log_file: Optional custom log file path for missing translations + """ + self.log_file = log_file or "logs/i18n/missing_translations.log" + self._missing_translations: Dict[str, Set[str]] = defaultdict(set) + self._missing_with_context: List[Dict] = [] + self._max_context_entries = 10000 # Keep last 10k entries + + # Ensure log directory exists + log_path = Path(self.log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + # Create dedicated file handler for missing translations + self._file_handler = logging.FileHandler( + self.log_file, + encoding='utf-8' + ) + self._file_handler.setLevel(logging.WARNING) + + # Create formatter + formatter = logging.Formatter( + fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + self._file_handler.setFormatter(formatter) + + # Create dedicated logger for missing translations + self._logger = logging.getLogger("i18n.missing_translations") + self._logger.setLevel(logging.WARNING) + self._logger.addHandler(self._file_handler) + self._logger.propagate = False # Don't propagate to root logger + + logger.info(f"TranslationLogger initialized with log file: {self.log_file}") + + def log_missing_translation( + self, + key: str, + locale: str, + context: Optional[Dict] = None + ): + """ + Log a missing translation. 
+ + Args: + key: Translation key that was not found + locale: Locale code + context: Optional context information (e.g., request path, user info) + """ + # Add to missing set + self._missing_translations[locale].add(key) + + # Create context entry + entry = { + "timestamp": datetime.now().isoformat(), + "key": key, + "locale": locale, + "context": context or {} + } + + # Keep only recent entries to avoid memory bloat + if len(self._missing_with_context) >= self._max_context_entries: + self._missing_with_context.pop(0) + + self._missing_with_context.append(entry) + + # Log to file + context_str = f" (context: {context})" if context else "" + self._logger.warning( + f"Missing translation: key='{key}', locale='{locale}'{context_str}" + ) + + def log_translation_error( + self, + error_type: str, + message: str, + key: Optional[str] = None, + locale: Optional[str] = None, + context: Optional[Dict] = None + ): + """ + Log a translation error. + + Args: + error_type: Type of error (e.g., "format_error", "parameter_missing") + message: Error message + key: Translation key (optional) + locale: Locale code (optional) + context: Optional context information + """ + error_data = { + "error_type": error_type, + "message": message, + "key": key, + "locale": locale, + "context": context or {}, + "timestamp": datetime.now().isoformat() + } + + self._logger.error( + f"Translation error: {error_type} - {message} " + f"(key: {key}, locale: {locale})" + ) + + def log_translation_success( + self, + key: str, + locale: str, + duration_ms: Optional[float] = None + ): + """ + Log a successful translation (debug level). 
+ + Args: + key: Translation key + locale: Locale code + duration_ms: Optional duration in milliseconds + """ + duration_str = f" ({duration_ms:.3f}ms)" if duration_ms else "" + logger.debug( + f"Translation success: key='{key}', locale='{locale}'{duration_str}" + ) + + def get_missing_translations( + self, + locale: Optional[str] = None + ) -> Dict[str, List[str]]: + """ + Get missing translations. + + Args: + locale: Specific locale (optional, returns all if None) + + Returns: + Dictionary of missing translations by locale + """ + if locale: + return {locale: sorted(list(self._missing_translations.get(locale, set())))} + + return { + loc: sorted(list(keys)) + for loc, keys in self._missing_translations.items() + } + + def get_missing_with_context( + self, + locale: Optional[str] = None, + limit: Optional[int] = None + ) -> List[Dict]: + """ + Get missing translations with context. + + Args: + locale: Filter by locale (optional) + limit: Maximum number of entries to return (optional) + + Returns: + List of missing translation entries with context + """ + entries = self._missing_with_context + + # Filter by locale if specified + if locale: + entries = [e for e in entries if e["locale"] == locale] + + # Apply limit if specified + if limit: + entries = entries[-limit:] + + return entries + + def generate_report( + self, + locale: Optional[str] = None, + output_file: Optional[str] = None + ) -> Dict: + """ + Generate a missing translation report. 
+ + Args: + locale: Specific locale (optional, generates for all if None) + output_file: Optional file path to save report as JSON + + Returns: + Report dictionary + """ + missing = self.get_missing_translations(locale) + + report = { + "generated_at": datetime.now().isoformat(), + "total_missing": sum(len(keys) for keys in missing.values()), + "missing_by_locale": { + loc: { + "count": len(keys), + "keys": keys + } + for loc, keys in missing.items() + }, + "recent_context": self.get_missing_with_context(locale, limit=100) + } + + # Save to file if specified + if output_file: + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(report, f, indent=2, ensure_ascii=False) + + logger.info(f"Missing translation report saved to: {output_file}") + + return report + + def get_statistics(self) -> Dict: + """ + Get statistics about missing translations. + + Returns: + Dictionary with statistics + """ + total_missing = sum(len(keys) for keys in self._missing_translations.values()) + + # Count by namespace + namespace_counts = defaultdict(int) + for locale, keys in self._missing_translations.items(): + for key in keys: + namespace = key.split('.')[0] if '.' in key else 'unknown' + namespace_counts[namespace] += 1 + + return { + "total_missing": total_missing, + "locales_affected": len(self._missing_translations), + "missing_by_locale": { + loc: len(keys) + for loc, keys in self._missing_translations.items() + }, + "missing_by_namespace": dict(namespace_counts), + "total_context_entries": len(self._missing_with_context) + } + + def clear(self, locale: Optional[str] = None): + """ + Clear missing translation records. 
+ + Args: + locale: Specific locale to clear (optional, clears all if None) + """ + if locale: + self._missing_translations.pop(locale, None) + self._missing_with_context = [ + e for e in self._missing_with_context + if e["locale"] != locale + ] + logger.info(f"Cleared missing translations for locale: {locale}") + else: + self._missing_translations.clear() + self._missing_with_context.clear() + logger.info("Cleared all missing translations") + + def export_to_json(self, output_file: str): + """ + Export all missing translations to JSON file. + + Args: + output_file: Output file path + """ + data = { + "exported_at": datetime.now().isoformat(), + "missing_translations": self.get_missing_translations(), + "statistics": self.get_statistics(), + "recent_context": self.get_missing_with_context(limit=1000) + } + + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + logger.info(f"Missing translations exported to: {output_file}") + + def __del__(self): + """Cleanup file handler on deletion.""" + try: + if hasattr(self, '_file_handler'): + self._file_handler.close() + self._logger.removeHandler(self._file_handler) + except Exception: + pass + + +# Global translation logger instance +_translation_logger: Optional[TranslationLogger] = None + + +def get_translation_logger() -> TranslationLogger: + """ + Get the global translation logger instance. + + Returns: + TranslationLogger singleton + """ + global _translation_logger + if _translation_logger is None: + _translation_logger = TranslationLogger() + return _translation_logger + + +def log_missing_translation( + key: str, + locale: str, + context: Optional[Dict] = None +): + """ + Log a missing translation (convenience function). 
+ + Args: + key: Translation key + locale: Locale code + context: Optional context information + """ + translation_logger = get_translation_logger() + translation_logger.log_missing_translation(key, locale, context) + + +def log_translation_error( + error_type: str, + message: str, + key: Optional[str] = None, + locale: Optional[str] = None, + context: Optional[Dict] = None +): + """ + Log a translation error (convenience function). + + Args: + error_type: Type of error + message: Error message + key: Translation key (optional) + locale: Locale code (optional) + context: Optional context information + """ + translation_logger = get_translation_logger() + translation_logger.log_translation_error( + error_type, message, key, locale, context + ) diff --git a/api/app/i18n/metrics.py b/api/app/i18n/metrics.py new file mode 100644 index 00000000..781ba83e --- /dev/null +++ b/api/app/i18n/metrics.py @@ -0,0 +1,337 @@ +""" +Performance monitoring and metrics for i18n system. + +This module provides: +- Translation request counters +- Translation timing metrics +- Missing translation tracking +- Performance monitoring decorators +- Prometheus-compatible metrics +""" + +import logging +import time +from functools import wraps +from typing import Any, Callable, Dict, Optional +from collections import defaultdict +from datetime import datetime + +logger = logging.getLogger(__name__) + + +class TranslationMetrics: + """ + Metrics collector for translation operations. 
+ + Tracks: + - Translation request counts + - Translation timing (latency) + - Missing translations + - Cache performance + - Locale usage + """ + + def __init__(self): + """Initialize metrics collector.""" + # Request counters by locale + self._request_counts: Dict[str, int] = defaultdict(int) + + # Missing translation tracker + self._missing_translations: Dict[str, set] = defaultdict(set) + + # Timing metrics (in milliseconds) + self._timing_data: list = [] + self._max_timing_samples = 10000 # Keep last 10k samples + + # Locale usage + self._locale_usage: Dict[str, int] = defaultdict(int) + + # Namespace usage + self._namespace_usage: Dict[str, int] = defaultdict(int) + + # Error counts + self._error_counts: Dict[str, int] = defaultdict(int) + + # Start time + self._start_time = datetime.now() + + logger.info("TranslationMetrics initialized") + + def record_request(self, locale: str, namespace: str = None): + """ + Record a translation request. + + Args: + locale: Locale code + namespace: Translation namespace (optional) + """ + self._request_counts[locale] += 1 + self._locale_usage[locale] += 1 + + if namespace: + self._namespace_usage[namespace] += 1 + + def record_missing(self, key: str, locale: str): + """ + Record a missing translation. + + Args: + key: Translation key + locale: Locale code + """ + self._missing_translations[locale].add(key) + logger.debug(f"Missing translation recorded: {key} (locale: {locale})") + + def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"): + """ + Record translation operation timing. 
+ + Args: + duration_ms: Duration in milliseconds + locale: Locale code + operation: Operation type + """ + # Keep only recent samples to avoid memory bloat + if len(self._timing_data) >= self._max_timing_samples: + self._timing_data.pop(0) + + self._timing_data.append({ + "duration_ms": duration_ms, + "locale": locale, + "operation": operation, + "timestamp": time.time() + }) + + def record_error(self, error_type: str): + """ + Record an error. + + Args: + error_type: Type of error + """ + self._error_counts[error_type] += 1 + + def get_summary(self) -> Dict[str, Any]: + """ + Get metrics summary. + + Returns: + Dictionary with metrics summary + """ + total_requests = sum(self._request_counts.values()) + total_missing = sum(len(keys) for keys in self._missing_translations.values()) + + # Calculate timing statistics + timing_stats = self._calculate_timing_stats() + + # Calculate uptime + uptime_seconds = (datetime.now() - self._start_time).total_seconds() + + return { + "uptime_seconds": round(uptime_seconds, 2), + "total_requests": total_requests, + "requests_per_locale": dict(self._request_counts), + "total_missing_translations": total_missing, + "missing_by_locale": { + locale: len(keys) + for locale, keys in self._missing_translations.items() + }, + "timing": timing_stats, + "locale_usage": dict(self._locale_usage), + "namespace_usage": dict(self._namespace_usage), + "error_counts": dict(self._error_counts) + } + + def _calculate_timing_stats(self) -> Dict[str, Any]: + """ + Calculate timing statistics. 
+ + Returns: + Dictionary with timing statistics + """ + if not self._timing_data: + return { + "count": 0, + "avg_ms": 0, + "min_ms": 0, + "max_ms": 0, + "p50_ms": 0, + "p95_ms": 0, + "p99_ms": 0 + } + + durations = [d["duration_ms"] for d in self._timing_data] + durations.sort() + + count = len(durations) + avg = sum(durations) / count + + # Calculate percentiles + p50_idx = int(count * 0.50) + p95_idx = int(count * 0.95) + p99_idx = int(count * 0.99) + + return { + "count": count, + "avg_ms": round(avg, 3), + "min_ms": round(durations[0], 3), + "max_ms": round(durations[-1], 3), + "p50_ms": round(durations[p50_idx], 3), + "p95_ms": round(durations[p95_idx], 3), + "p99_ms": round(durations[p99_idx], 3) + } + + def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]: + """ + Get missing translations. + + Args: + locale: Specific locale (optional, returns all if None) + + Returns: + Dictionary of missing translations by locale + """ + if locale: + return {locale: list(self._missing_translations.get(locale, set()))} + + return { + locale: list(keys) + for locale, keys in self._missing_translations.items() + } + + def reset(self): + """Reset all metrics.""" + self._request_counts.clear() + self._missing_translations.clear() + self._timing_data.clear() + self._locale_usage.clear() + self._namespace_usage.clear() + self._error_counts.clear() + self._start_time = datetime.now() + logger.info("Metrics reset") + + def export_prometheus(self) -> str: + """ + Export metrics in Prometheus format. 
+ + Returns: + Prometheus-formatted metrics string + """ + lines = [] + + # Translation requests counter + lines.append("# HELP i18n_translation_requests_total Total number of translation requests") + lines.append("# TYPE i18n_translation_requests_total counter") + for locale, count in self._request_counts.items(): + lines.append(f'i18n_translation_requests_total{{locale="{locale}"}} {count}') + + # Missing translations counter + lines.append("# HELP i18n_missing_translations_total Total number of missing translations") + lines.append("# TYPE i18n_missing_translations_total counter") + for locale, keys in self._missing_translations.items(): + lines.append(f'i18n_missing_translations_total{{locale="{locale}"}} {len(keys)}') + + # Timing metrics + timing_stats = self._calculate_timing_stats() + lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds") + lines.append("# TYPE i18n_translation_duration_ms summary") + lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}') + lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}') + lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}') + lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}') + lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}') + + # Error counter + lines.append("# HELP i18n_errors_total Total number of i18n errors") + lines.append("# TYPE i18n_errors_total counter") + for error_type, count in self._error_counts.items(): + lines.append(f'i18n_errors_total{{type="{error_type}"}} {count}') + + return "\n".join(lines) + + +# Global metrics instance +_metrics: Optional[TranslationMetrics] = None + + +def get_metrics() -> TranslationMetrics: + """ + Get the global metrics instance. 
+ + Returns: + TranslationMetrics singleton + """ + global _metrics + if _metrics is None: + _metrics = TranslationMetrics() + return _metrics + + +def monitor_performance(operation: str = "translate"): + """ + Decorator to monitor translation operation performance. + + Args: + operation: Operation name for metrics + + Returns: + Decorated function + + Example: + @monitor_performance("translate") + def translate(key: str, locale: str) -> str: + ... + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.perf_counter() + + try: + result = func(*args, **kwargs) + + # Record timing + duration_ms = (time.perf_counter() - start_time) * 1000 + + # Try to extract locale from args/kwargs + locale = kwargs.get("locale", "unknown") + if not locale and len(args) > 1: + locale = args[1] if isinstance(args[1], str) else "unknown" + + metrics = get_metrics() + metrics.record_timing(duration_ms, locale, operation) + + return result + + except Exception as e: + # Record error + metrics = get_metrics() + metrics.record_error(type(e).__name__) + raise + + return wrapper + return decorator + + +def track_missing_translation(key: str, locale: str): + """ + Track a missing translation. + + Args: + key: Translation key + locale: Locale code + """ + metrics = get_metrics() + metrics.record_missing(key, locale) + + +def track_translation_request(locale: str, namespace: str = None): + """ + Track a translation request. + + Args: + locale: Locale code + namespace: Translation namespace (optional) + """ + metrics = get_metrics() + metrics.record_request(locale, namespace) diff --git a/api/app/i18n/middleware.py b/api/app/i18n/middleware.py new file mode 100644 index 00000000..2e945dde --- /dev/null +++ b/api/app/i18n/middleware.py @@ -0,0 +1,202 @@ +""" +Language detection middleware for i18n system. + +This middleware determines the language to use for each request based on: +1. Query parameter (?lang=en) +2. 
Accept-Language HTTP header +3. User language preference (from database) +4. Tenant default language +5. System default language + +The detected language is injected into request.state.language and +added to the response Content-Language header. +""" + +import logging +import re +from typing import Optional + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +logger = logging.getLogger(__name__) + + +class LanguageMiddleware(BaseHTTPMiddleware): + """ + Language detection middleware. + + Determines the language for each request based on multiple sources + with a clear priority order, validates the language is supported, + and injects it into the request context. + """ + + async def dispatch(self, request: Request, call_next): + """ + Process the request and determine the language. + + Args: + request: The incoming request + call_next: The next middleware/handler in the chain + + Returns: + Response with Content-Language header added + """ + # Determine the language for this request + language = await self._determine_language(request) + + # Validate language is supported + from app.core.config import settings + if language not in settings.I18N_SUPPORTED_LANGUAGES: + logger.warning( + f"Unsupported language '{language}' requested, " + f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}" + ) + language = settings.I18N_DEFAULT_LANGUAGE + + # Inject language into request state + request.state.language = language + + # Also set in context variable for exception handling + from app.i18n.exceptions import set_current_locale + set_current_locale(language) + + logger.debug(f"Request language set to: {language}") + + # Process the request + response = await call_next(request) + + # Add Content-Language header to response + response.headers["Content-Language"] = language + + return response + + async def _determine_language(self, request: Request) -> str: + """ + Determine the language to use based on priority order. 
+ + Priority: + 1. Query parameter (?lang=en) + 2. Accept-Language HTTP header + 3. User language preference (from database) + 4. Tenant default language + 5. System default language + + Args: + request: The incoming request + + Returns: + Language code (e.g., "zh", "en") + """ + from app.core.config import settings + + # 1. Check query parameter (?lang=en) + if "lang" in request.query_params: + lang = request.query_params["lang"].strip().lower() + if lang: + logger.debug(f"Language from query parameter: {lang}") + return lang + + # 2. Check Accept-Language HTTP header + if "Accept-Language" in request.headers: + lang = self._parse_accept_language( + request.headers["Accept-Language"] + ) + if lang: + logger.debug(f"Language from Accept-Language header: {lang}") + return lang + + # 3. Check user language preference (requires authentication) + # Note: This assumes user is already loaded into request.state by auth middleware + if hasattr(request.state, "user") and request.state.user: + user = request.state.user + if hasattr(user, "preferred_language") and user.preferred_language: + logger.debug( + f"Language from user preference: {user.preferred_language}" + ) + return user.preferred_language + + # 4. Check tenant default language + # Note: This assumes tenant is already loaded into request.state + if hasattr(request.state, "tenant") and request.state.tenant: + tenant = request.state.tenant + if hasattr(tenant, "default_language") and tenant.default_language: + logger.debug( + f"Language from tenant default: {tenant.default_language}" + ) + return tenant.default_language + + # 5. Fall back to system default language + logger.debug( + f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}" + ) + return settings.I18N_DEFAULT_LANGUAGE + + def _parse_accept_language(self, header: str) -> Optional[str]: + """ + Parse the Accept-Language HTTP header. 
+ + The Accept-Language header format: + Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7 + + This method: + 1. Parses all language codes and their quality values + 2. Extracts the base language code (zh-CN -> zh) + 3. Sorts by quality value (higher first) + 4. Returns the first supported language + + Args: + header: Accept-Language header value + + Returns: + Language code if found and supported, None otherwise + + Examples: + _parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8") + # => "zh" (if zh is supported) + + _parse_accept_language("en-US,en;q=0.9") + # => "en" (if en is supported) + """ + from app.core.config import settings + + if not header: + return None + + # Parse language preferences with quality values + languages = [] + + for item in header.split(","): + item = item.strip() + if not item: + continue + + # Split language code and quality value + parts = item.split(";") + lang_code = parts[0].strip() + + # Extract base language code (zh-CN -> zh, en-US -> en) + base_lang = lang_code.split("-")[0].lower() + + # Extract quality value (default: 1.0) + quality = 1.0 + if len(parts) > 1: + # Look for q=0.9 pattern + q_match = re.search(r"q=([\d.]+)", parts[1]) + if q_match: + try: + quality = float(q_match.group(1)) + except ValueError: + quality = 1.0 + + languages.append((base_lang, quality)) + + # Sort by quality value (descending) + languages.sort(key=lambda x: x[1], reverse=True) + + # Return the first supported language + for lang_code, _ in languages: + if lang_code in settings.I18N_SUPPORTED_LANGUAGES: + return lang_code + + return None diff --git a/api/app/i18n/serializers.py b/api/app/i18n/serializers.py new file mode 100644 index 00000000..15ba4de5 --- /dev/null +++ b/api/app/i18n/serializers.py @@ -0,0 +1,221 @@ +""" +国际化响应序列化器 + +提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。 +""" + +from typing import Any, Dict, List, Union +from pydantic import BaseModel + + +class I18nResponseMixin: + """国际化响应混入类 + + 为响应数据添加国际化字段,特别是为枚举值添加 _display 后缀的翻译字段。 
+ + 使用方法: + 1. 继承此类 + 2. 实现 _get_enum_fields() 方法定义需要翻译的枚举字段 + 3. 调用 serialize_with_i18n() 方法序列化数据 + + 示例: + class WorkspaceSerializer(I18nResponseMixin): + def _get_enum_fields(self) -> Dict[str, str]: + return { + "role": "workspace_role", + "status": "workspace_status" + } + + def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict: + data = { + "id": str(workspace.id), + "name": workspace.name, + "role": workspace.role, + "status": workspace.status + } + return self.serialize_with_i18n(data, locale) + """ + + def serialize_with_i18n( + self, + data: Any, + locale: str = "zh" + ) -> Union[Dict, List[Dict], Any]: + """序列化数据并添加国际化字段 + + Args: + data: 要序列化的数据(字典、列表或 Pydantic 模型) + locale: 语言代码 + + Returns: + 序列化后的数据,包含国际化字段 + """ + # 如果是 Pydantic 模型,转换为字典 + if isinstance(data, BaseModel): + data = data.model_dump() + + # 处理不同类型的数据 + if isinstance(data, dict): + return self._serialize_dict(data, locale) + elif isinstance(data, list): + return [self._serialize_dict(item, locale) if isinstance(item, dict) else item for item in data] + else: + return data + + def _serialize_dict(self, data: Dict, locale: str) -> Dict: + """序列化字典并添加 _display 字段 + + Args: + data: 字典数据 + locale: 语言代码 + + Returns: + 添加了 _display 字段的字典 + """ + from app.i18n.service import get_translation_service + + translation_service = get_translation_service() + + result = data.copy() + + # 获取需要翻译的枚举字段 + enum_fields = self._get_enum_fields() + + # 为每个枚举字段添加 _display 字段 + for field, enum_type in enum_fields.items(): + if field in result and result[field] is not None: + value = result[field] + # 翻译枚举值 + display_value = translation_service.translate_enum( + enum_type=enum_type, + value=str(value), + locale=locale + ) + # 添加 _display 字段 + result[f"{field}_display"] = display_value + + return result + + def _get_enum_fields(self) -> Dict[str, str]: + """获取需要翻译的枚举字段 + + 子类必须实现此方法,返回字段名到枚举类型的映射。 + + Returns: + 字段名到枚举类型的映射 + 例如: {"role": "workspace_role", "status": "workspace_status"} + """ + 
return {} + + +class WorkspaceSerializer(I18nResponseMixin): + """工作空间序列化器 + + 为工作空间响应添加国际化字段。 + """ + + def _get_enum_fields(self) -> Dict[str, str]: + """定义工作空间的枚举字段""" + return { + "role": "workspace_role", + "status": "workspace_status" + } + + def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict: + """序列化工作空间数据 + + Args: + workspace_data: 工作空间数据(字典或 Pydantic 模型) + locale: 语言代码 + + Returns: + 序列化后的工作空间数据,包含国际化字段 + """ + return self.serialize_with_i18n(workspace_data, locale) + + def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]: + """序列化工作空间列表 + + Args: + workspaces: 工作空间列表 + locale: 语言代码 + + Returns: + 序列化后的工作空间列表 + """ + return [self.serialize(ws, locale) for ws in workspaces] + + +class WorkspaceMemberSerializer(I18nResponseMixin): + """工作空间成员序列化器 + + 为工作空间成员响应添加国际化字段。 + """ + + def _get_enum_fields(self) -> Dict[str, str]: + """定义工作空间成员的枚举字段""" + return { + "role": "workspace_role" + } + + def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict: + """序列化工作空间成员数据 + + Args: + member_data: 成员数据(字典或 Pydantic 模型) + locale: 语言代码 + + Returns: + 序列化后的成员数据,包含国际化字段 + """ + return self.serialize_with_i18n(member_data, locale) + + def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]: + """序列化工作空间成员列表 + + Args: + members: 成员列表 + locale: 语言代码 + + Returns: + 序列化后的成员列表 + """ + return [self.serialize(member, locale) for member in members] + + +class WorkspaceInviteSerializer(I18nResponseMixin): + """工作空间邀请序列化器 + + 为工作空间邀请响应添加国际化字段。 + """ + + def _get_enum_fields(self) -> Dict[str, str]: + """定义工作空间邀请的枚举字段""" + return { + "status": "invite_status", + "role": "workspace_role" + } + + def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict: + """序列化工作空间邀请数据 + + Args: + invite_data: 邀请数据(字典或 Pydantic 模型) + locale: 语言代码 + + Returns: + 序列化后的邀请数据,包含国际化字段 + """ + return 
self.serialize_with_i18n(invite_data, locale) + + def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]: + """序列化工作空间邀请列表 + + Args: + invites: 邀请列表 + locale: 语言代码 + + Returns: + 序列化后的邀请列表 + """ + return [self.serialize(invite, locale) for invite in invites] diff --git a/api/app/i18n/service.py b/api/app/i18n/service.py new file mode 100644 index 00000000..9cbc0926 --- /dev/null +++ b/api/app/i18n/service.py @@ -0,0 +1,370 @@ +""" +Translation service for i18n system. + +This module provides the core translation functionality including: +- Translation lookup with fallback mechanism +- Parameterized message support +- Enum value translation +- Memory caching for performance +- Performance monitoring and metrics +""" + +import logging +from functools import lru_cache +from typing import Any, Dict, Optional + +from app.i18n.loader import TranslationLoader +from app.i18n.cache import TranslationCache +from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request +from app.i18n.logger import get_translation_logger + +logger = logging.getLogger(__name__) + + +class TranslationService: + """ + Translation service that provides: + - Fast translation lookup with memory cache + - Parameterized message support ({param} syntax) + - Fallback mechanism (current locale → default locale → key) + - Enum value translation + - Deep merge of multi-directory translations + """ + + def __init__(self, locales_dirs: Optional[list] = None): + """ + Initialize the translation service. + + Args: + locales_dirs: List of directories containing translation files. + If None, will auto-detect from settings. 
+ """ + from app.core.config import settings + + self.loader = TranslationLoader(locales_dirs) + self.default_locale = settings.I18N_DEFAULT_LANGUAGE + self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE + self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS + self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE + + # Initialize advanced cache with LRU + lru_cache_size = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000) + self.cache = TranslationCache( + max_lru_size=lru_cache_size, + enable_lazy_load=False # Load all at startup for now + ) + + # Load all translations into cache + self._load_all_locales() + + # Initialize metrics + self.metrics = get_metrics() + + # Initialize translation logger + self.translation_logger = get_translation_logger() + + logger.info( + f"TranslationService initialized with default locale: {self.default_locale}, " + f"LRU cache size: {lru_cache_size}" + ) + + def _load_all_locales(self): + """Load all available locales into memory cache.""" + available_locales = self.loader.get_available_locales() + logger.info(f"Loading translations for locales: {available_locales}") + + for locale in available_locales: + locale_data = self.loader.load_locale(locale) + self.cache.set_locale_data(locale, locale_data) + + logger.info(f"Loaded {len(available_locales)} locales into cache") + + @monitor_performance("translate") + def translate( + self, + key: str, + locale: Optional[str] = None, + **params + ) -> str: + """ + Translate a key to the target locale. 
+ + Supports: + - Dot-separated keys (e.g., "common.success.created") + - Parameterized messages (e.g., "Hello {name}") + - Fallback mechanism + + Args: + key: Translation key (format: "namespace.key.subkey") + locale: Target locale (defaults to default locale) + **params: Parameters for parameterized messages + + Returns: + Translated string, or the key itself if translation not found + + Examples: + translate("common.success.created", "zh") + # => "创建成功" + + translate("common.validation.required", "zh", field="名称") + # => "名称不能为空" + """ + if locale is None: + locale = self.default_locale + + # Parse key (namespace.key.subkey) + parts = key.split(".", 1) + if len(parts) < 2: + if self.log_missing: + logger.warning(f"Invalid translation key format: {key}") + return key + + namespace = parts[0] + key_path = parts[1].split(".") + + # Track request + track_translation_request(locale, namespace) + + # Get translation from cache + translation = self.cache.get_translation(locale, namespace, key_path) + + # Fallback to default locale if not found + if translation is None and locale != self.fallback_locale: + translation = self.cache.get_translation( + self.fallback_locale, namespace, key_path + ) + + # If still not found, return the key itself + if translation is None: + if self.log_missing: + logger.warning( + f"Missing translation: {key} (locale: {locale})" + ) + track_missing_translation(key, locale) + + # Log to translation logger with context + self.translation_logger.log_missing_translation( + key=key, + locale=locale, + context={"namespace": namespace} + ) + return key + + # Apply parameters if provided + if params: + try: + translation = translation.format(**params) + except KeyError as e: + error_msg = f"Missing parameter in translation '{key}': {e}" + logger.error(error_msg) + self.translation_logger.log_translation_error( + error_type="parameter_missing", + message=error_msg, + key=key, + locale=locale, + context={"params": list(params.keys())} + ) + except 
Exception as e: + error_msg = f"Error formatting translation '{key}': {e}" + logger.error(error_msg) + self.translation_logger.log_translation_error( + error_type="format_error", + message=error_msg, + key=key, + locale=locale + ) + + return translation + + def _get_translation( + self, + locale: str, + namespace: str, + key_path: list + ) -> Optional[str]: + """ + Get translation from cache (deprecated, use cache.get_translation). + + Args: + locale: Locale code + namespace: Translation namespace + key_path: List of nested keys + + Returns: + Translation string or None if not found + """ + return self.cache.get_translation(locale, namespace, key_path) + + @monitor_performance("translate_enum") + def translate_enum( + self, + enum_type: str, + value: str, + locale: Optional[str] = None + ) -> str: + """ + Translate an enum value. + + Args: + enum_type: Enum type name (e.g., "workspace_role") + value: Enum value (e.g., "manager") + locale: Target locale + + Returns: + Translated enum display name + + Examples: + translate_enum("workspace_role", "manager", "zh") + # => "管理员" + + translate_enum("invite_status", "pending", "en") + # => "Pending" + """ + key = f"enums.{enum_type}.{value}" + return self.translate(key, locale) + + def has_translation(self, key: str, locale: str) -> bool: + """ + Check if a translation exists for the given key and locale. + + Args: + key: Translation key + locale: Locale code + + Returns: + True if translation exists, False otherwise + """ + parts = key.split(".", 1) + if len(parts) < 2: + return False + + namespace = parts[0] + key_path = parts[1].split(".") + + translation = self.cache.get_translation(locale, namespace, key_path) + return translation is not None + + def reload(self, locale: Optional[str] = None): + """ + Reload translation files. + + Args: + locale: Specific locale to reload. If None, reloads all locales. 
+ """ + logger.info(f"Reloading translations for locale: {locale or 'all'}") + + if locale: + locale_data = self.loader.load_locale(locale) + self.cache.set_locale_data(locale, locale_data) + # Clear LRU cache for this locale + self.cache.clear_locale(locale) + else: + self._load_all_locales() + # Clear all LRU cache + self.cache.clear_lru() + + logger.info("Translation reload completed") + + def get_available_locales(self) -> list: + """ + Get list of all available locales. + + Returns: + List of locale codes + """ + return self.cache.get_loaded_locales() + + def get_cache_stats(self) -> Dict[str, Any]: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics + """ + return self.cache.get_stats() + + def get_metrics_summary(self) -> Dict[str, Any]: + """ + Get metrics summary. + + Returns: + Dictionary with metrics summary + """ + return self.metrics.get_summary() + + def get_memory_usage(self) -> Dict[str, Any]: + """ + Get memory usage information. + + Returns: + Dictionary with memory usage information + """ + return self.cache.get_memory_usage() + + def get_loaded_dirs(self) -> list: + """ + Get list of loaded translation directories. + + Returns: + List of directory paths + """ + return self.loader.locales_dirs + + +# Global singleton instance +_translation_service: Optional[TranslationService] = None + + +def get_translation_service() -> TranslationService: + """ + Get the global translation service instance. + + Returns: + TranslationService singleton + """ + global _translation_service + if _translation_service is None: + _translation_service = TranslationService() + return _translation_service + + +# Convenience functions for easy access +def t(key: str, locale: Optional[str] = None, **params) -> str: + """ + Translate a key (convenience function). 
+ + Args: + key: Translation key + locale: Target locale (optional, uses default if not provided) + **params: Parameters for parameterized messages + + Returns: + Translated string + + Examples: + t("common.success.created") + t("common.validation.required", field="名称") + t("workspace.member_count", count=5) + """ + service = get_translation_service() + return service.translate(key, locale, **params) + + +def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str: + """ + Translate an enum value (convenience function). + + Args: + enum_type: Enum type name + value: Enum value + locale: Target locale + + Returns: + Translated enum display name + + Examples: + t_enum("workspace_role", "manager") + t_enum("invite_status", "pending", "en") + """ + service = get_translation_service() + return service.translate_enum(enum_type, value, locale) diff --git a/api/app/locales/en/README.md b/api/app/locales/en/README.md new file mode 100644 index 00000000..0a605a60 --- /dev/null +++ b/api/app/locales/en/README.md @@ -0,0 +1,26 @@ +# English Translation Files + +This directory contains English translation files. + +## File Structure + +- `common.json` - Common translations (success messages, actions, validation) +- `auth.json` - Authentication module translations +- `workspace.json` - Workspace module translations +- `tenant.json` - Tenant module translations +- `errors.json` - Error message translations +- `enums.json` - Enum value translations + +## Translation File Format + +All translation files use JSON format and support nested structures. 
+ +Example: +```json +{ + "success": { + "created": "Created successfully", + "updated": "Updated successfully" + } +} +``` diff --git a/api/app/locales/en/auth.json b/api/app/locales/en/auth.json new file mode 100644 index 00000000..50ba866b --- /dev/null +++ b/api/app/locales/en/auth.json @@ -0,0 +1,55 @@ +{ + "login": { + "success": "Login successful", + "failed": "Login failed", + "invalid_credentials": "Invalid username or password", + "account_locked": "Account has been locked", + "account_disabled": "Account has been disabled" + }, + "logout": { + "success": "Logout successful", + "failed": "Logout failed" + }, + "token": { + "refresh_success": "Token refreshed successfully", + "invalid": "Invalid token", + "expired": "Token has expired", + "blacklisted": "Token has been invalidated", + "invalid_refresh_token": "Invalid refresh token", + "refresh_token_blacklisted": "Refresh token has been invalidated" + }, + "registration": { + "success": "Registration successful", + "failed": "Registration failed", + "email_exists": "Email already in use", + "username_exists": "Username already taken" + }, + "password": { + "reset_success": "Password reset successful", + "reset_failed": "Password reset failed", + "change_success": "Password changed successfully", + "change_failed": "Password change failed", + "incorrect": "Incorrect password", + "too_weak": "Password is too weak", + "mismatch": "Passwords do not match" + }, + "invite": { + "invalid": "Invalid or expired invite code", + "email_mismatch": "Invite email does not match login email", + "accept_success": "Invite accepted successfully", + "accept_failed": "Failed to accept invite", + "password_verification_failed": "Failed to accept invite, password verification error", + "bind_workspace_success": "Workspace bound successfully", + "bind_workspace_failed": "Failed to bind workspace" + }, + "user": { + "not_found": "User not found", + "already_exists": "User already exists", + "created_with_invite": "User created 
successfully and joined workspace" + }, + "session": { + "expired": "Session expired, please log in again", + "invalid": "Invalid session", + "single_session_enabled": "Single-session login enabled, sessions on other devices will be logged out" + } +} diff --git a/api/app/locales/en/common.json b/api/app/locales/en/common.json new file mode 100644 index 00000000..505f83e3 --- /dev/null +++ b/api/app/locales/en/common.json @@ -0,0 +1,132 @@ +{ + "success": { + "created": "Created successfully", + "updated": "Updated successfully", + "deleted": "Deleted successfully", + "retrieved": "Retrieved successfully", + "saved": "Saved successfully", + "uploaded": "Uploaded successfully", + "downloaded": "Downloaded successfully", + "sent": "Sent successfully", + "completed": "Completed", + "confirmed": "Confirmed", + "cancelled": "Cancelled", + "archived": "Archived", + "restored": "Restored" + }, + "actions": { + "create": "Create", + "update": "Update", + "delete": "Delete", + "view": "View", + "edit": "Edit", + "save": "Save", + "cancel": "Cancel", + "confirm": "Confirm", + "submit": "Submit", + "upload": "Upload", + "download": "Download", + "send": "Send", + "search": "Search", + "filter": "Filter", + "sort": "Sort", + "export": "Export", + "import": "Import", + "refresh": "Refresh", + "reset": "Reset", + "back": "Back", + "next": "Next", + "previous": "Previous", + "finish": "Finish", + "close": "Close", + "open": "Open", + "archive": "Archive", + "restore": "Restore", + "duplicate": "Duplicate", + "share": "Share", + "invite": "Invite", + "remove": "Remove", + "add": "Add", + "select": "Select", + "clear": "Clear" + }, + "validation": { + "required": "{field} is required", + "invalid_format": "{field} format is invalid", + "too_long": "{field} cannot exceed {max} characters", + "too_short": "{field} must be at least {min} characters", + "invalid_email": "Invalid email format", + "invalid_url": "Invalid URL format", + "invalid_phone": "Invalid phone number format", + 
"invalid_date": "Invalid date format", + "invalid_number": "Must be a valid number", + "out_of_range": "{field} must be between {min} and {max}", + "already_exists": "{field} already exists", + "not_found": "{field} not found", + "invalid_value": "Invalid value for {field}", + "password_mismatch": "Passwords do not match", + "weak_password": "Password is too weak, please use a stronger password", + "invalid_credentials": "Invalid username or password", + "unauthorized": "Unauthorized access", + "forbidden": "Permission denied", + "expired": "{field} has expired", + "invalid_token": "Invalid token", + "file_too_large": "File size cannot exceed {max}", + "invalid_file_type": "Unsupported file type", + "duplicate": "Duplicate {field}" + }, + "status": { + "active": "Active", + "inactive": "Inactive", + "pending": "Pending", + "processing": "Processing", + "completed": "Completed", + "failed": "Failed", + "cancelled": "Cancelled", + "archived": "Archived", + "deleted": "Deleted", + "draft": "Draft", + "published": "Published", + "suspended": "Suspended", + "expired": "Expired" + }, + "messages": { + "loading": "Loading...", + "saving": "Saving...", + "processing": "Processing...", + "uploading": "Uploading...", + "downloading": "Downloading...", + "no_data": "No data available", + "no_results": "No results found", + "confirm_delete": "Are you sure you want to delete? 
This action cannot be undone.", + "confirm_action": "Are you sure you want to perform this action?", + "operation_success": "Operation successful", + "operation_failed": "Operation failed", + "please_wait": "Please wait...", + "try_again": "Please try again", + "contact_support": "If the problem persists, please contact support" + }, + "pagination": { + "page": "Page {page}", + "of": "of {total}", + "items": "{total} items", + "per_page": "{count} per page", + "showing": "Showing {from} to {to} of {total}", + "first": "First", + "last": "Last", + "next": "Next", + "previous": "Previous" + }, + "time": { + "just_now": "Just now", + "minutes_ago": "{count} minutes ago", + "hours_ago": "{count} hours ago", + "days_ago": "{count} days ago", + "weeks_ago": "{count} weeks ago", + "months_ago": "{count} months ago", + "years_ago": "{count} years ago", + "today": "Today", + "yesterday": "Yesterday", + "tomorrow": "Tomorrow" + } +} diff --git a/api/app/locales/en/enums.json b/api/app/locales/en/enums.json new file mode 100644 index 00000000..da7a3ace --- /dev/null +++ b/api/app/locales/en/enums.json @@ -0,0 +1,132 @@ +{ + "workspace_role": { + "owner": "Owner", + "manager": "Manager", + "member": "Member", + "guest": "Guest" + }, + "workspace_status": { + "active": "Active", + "inactive": "Inactive", + "archived": "Archived", + "suspended": "Suspended", + "deleted": "Deleted" + }, + "invite_status": { + "pending": "Pending", + "accepted": "Accepted", + "rejected": "Rejected", + "revoked": "Revoked", + "expired": "Expired" + }, + "user_status": { + "active": "Active", + "inactive": "Inactive", + "suspended": "Suspended", + "deleted": "Deleted", + "pending": "Pending" + }, + "tenant_status": { + "active": "Active", + "inactive": "Inactive", + "suspended": "Suspended", + "expired": "Expired", + "trial": "Trial" + }, + "file_status": { + "uploading": "Uploading", + "processing": "Processing", + "completed": "Completed", + "failed": "Failed", + "deleted": "Deleted" + }, + 
"task_status": { + "pending": "Pending", + "running": "Running", + "completed": "Completed", + "failed": "Failed", + "cancelled": "Cancelled", + "paused": "Paused" + }, + "priority": { + "low": "Low", + "medium": "Medium", + "high": "High", + "urgent": "Urgent" + }, + "visibility": { + "public": "Public", + "private": "Private", + "internal": "Internal", + "shared": "Shared" + }, + "permission": { + "read": "Read", + "write": "Write", + "delete": "Delete", + "admin": "Admin", + "owner": "Owner" + }, + "notification_type": { + "info": "Info", + "warning": "Warning", + "error": "Error", + "success": "Success" + }, + "language": { + "zh": "Chinese (Simplified)", + "en": "English", + "ja": "Japanese", + "ko": "Korean", + "fr": "French", + "de": "German", + "es": "Spanish" + }, + "timezone": { + "utc": "UTC", + "asia_shanghai": "Asia/Shanghai", + "asia_tokyo": "Asia/Tokyo", + "america_new_york": "America/New_York", + "europe_london": "Europe/London" + }, + "date_format": { + "short": "Short", + "medium": "Medium", + "long": "Long", + "full": "Full" + }, + "sort_order": { + "asc": "Ascending", + "desc": "Descending" + }, + "filter_operator": { + "equals": "Equals", + "not_equals": "Not Equals", + "contains": "Contains", + "not_contains": "Not Contains", + "starts_with": "Starts With", + "ends_with": "Ends With", + "greater_than": "Greater Than", + "less_than": "Less Than", + "greater_or_equal": "Greater or Equal", + "less_or_equal": "Less or Equal", + "in": "In", + "not_in": "Not In", + "is_null": "Is Null", + "is_not_null": "Is Not Null" + }, + "log_level": { + "debug": "Debug", + "info": "Info", + "warning": "Warning", + "error": "Error", + "critical": "Critical" + }, + "api_method": { + "get": "GET", + "post": "POST", + "put": "PUT", + "patch": "PATCH", + "delete": "DELETE" + } +} diff --git a/api/app/locales/en/errors.json b/api/app/locales/en/errors.json new file mode 100644 index 00000000..d0276dc9 --- /dev/null +++ b/api/app/locales/en/errors.json @@ -0,0 +1,138 
@@ +{ + "common": { + "internal_error": "Internal server error", + "network_error": "Network connection error", + "timeout": "Request timeout", + "service_unavailable": "Service temporarily unavailable", + "bad_request": "Bad request parameters", + "unauthorized": "Unauthorized access", + "forbidden": "Access forbidden", + "not_found": "Resource not found", + "method_not_allowed": "Method not allowed", + "conflict": "Resource conflict", + "too_many_requests": "Too many requests, please try again later", + "validation_failed": "Validation failed", + "database_error": "Database operation failed", + "file_operation_error": "File operation failed" + }, + "auth": { + "invalid_credentials": "Invalid username or password", + "token_expired": "Session expired, please log in again", + "token_invalid": "Invalid authentication token", + "token_missing": "Authentication token missing", + "unauthorized": "Unauthorized access", + "forbidden": "Permission denied", + "account_locked": "Account has been locked", + "account_disabled": "Account has been disabled", + "account_not_verified": "Account not verified", + "password_incorrect": "Incorrect password", + "password_too_weak": "Password is too weak", + "password_expired": "Password expired, please change it", + "email_not_verified": "Email not verified", + "phone_not_verified": "Phone number not verified", + "verification_code_invalid": "Invalid verification code", + "verification_code_expired": "Verification code expired", + "login_failed": "Login failed", + "logout_failed": "Logout failed", + "session_expired": "Session expired", + "already_logged_in": "Already logged in", + "not_logged_in": "Not logged in" + }, + "user": { + "not_found": "User not found", + "already_exists": "User already exists", + "email_already_exists": "Email already in use", + "phone_already_exists": "Phone number already in use", + "username_already_exists": "Username already taken", + "invalid_email": "Invalid email format", + "invalid_phone": "Invalid 
phone number format", + "invalid_username": "Invalid username format", + "create_failed": "Failed to create user", + "update_failed": "Failed to update user", + "delete_failed": "Failed to delete user", + "cannot_delete_self": "Cannot delete yourself", + "cannot_update_self_role": "Cannot update your own role", + "profile_update_failed": "Failed to update profile", + "avatar_upload_failed": "Failed to upload avatar", + "password_change_failed": "Failed to change password", + "old_password_incorrect": "Old password is incorrect" + }, + "workspace": { + "not_found": "Workspace not found", + "already_exists": "Workspace already exists", + "name_required": "Workspace name is required", + "name_too_long": "Workspace name is too long", + "create_failed": "Failed to create workspace", + "update_failed": "Failed to update workspace", + "delete_failed": "Failed to delete workspace", + "permission_denied": "Permission denied to access this workspace", + "not_member": "Not a workspace member", + "already_member": "Already a workspace member", + "member_limit_reached": "Member limit reached", + "cannot_leave_last_manager": "Cannot leave, you are the last manager", + "cannot_remove_last_manager": "Cannot remove the last manager", + "cannot_remove_self": "Cannot remove yourself", + "invite_not_found": "Invite not found", + "invite_expired": "Invite has expired", + "invite_already_accepted": "Invite already accepted", + "invite_already_revoked": "Invite already revoked", + "invite_send_failed": "Failed to send invite", + "archived": "Workspace is archived", + "suspended": "Workspace is suspended" + }, + "tenant": { + "not_found": "Tenant not found", + "already_exists": "Tenant already exists", + "create_failed": "Failed to create tenant", + "update_failed": "Failed to update tenant", + "delete_failed": "Failed to delete tenant", + "suspended": "Tenant is suspended", + "expired": "Tenant has expired", + "license_invalid": "Invalid license", + "license_expired": "License has 
expired", + "quota_exceeded": "Quota exceeded" + }, + "file": { + "not_found": "File not found", + "upload_failed": "File upload failed", + "download_failed": "File download failed", + "delete_failed": "File deletion failed", + "too_large": "File size exceeds limit", + "invalid_type": "Unsupported file type", + "invalid_format": "Invalid file format", + "corrupted": "File is corrupted", + "storage_full": "Storage is full", + "access_denied": "Access denied to this file" + }, + "api": { + "rate_limit_exceeded": "API rate limit exceeded", + "quota_exceeded": "API quota exceeded", + "invalid_api_key": "Invalid API key", + "api_key_expired": "API key has expired", + "api_key_revoked": "API key has been revoked", + "endpoint_not_found": "API endpoint not found", + "method_not_allowed": "Method not allowed", + "invalid_request": "Invalid request", + "missing_parameter": "Missing required parameter: {param}", + "invalid_parameter": "Invalid parameter: {param}" + }, + "database": { + "connection_failed": "Database connection failed", + "query_failed": "Database query failed", + "transaction_failed": "Database transaction failed", + "constraint_violation": "Data constraint violation", + "duplicate_key": "Duplicate data", + "foreign_key_violation": "Foreign key constraint violation", + "deadlock": "Database deadlock" + }, + "validation": { + "invalid_input": "Invalid input data", + "missing_field": "Missing required field: {field}", + "invalid_field": "Invalid field: {field}", + "field_too_long": "Field too long: {field}", + "field_too_short": "Field too short: {field}", + "invalid_format": "Invalid format: {field}", + "invalid_value": "Invalid value: {field}", + "out_of_range": "Value out of range: {field}" + } +} diff --git a/api/app/locales/en/i18n.json b/api/app/locales/en/i18n.json new file mode 100644 index 00000000..1662836d --- /dev/null +++ b/api/app/locales/en/i18n.json @@ -0,0 +1,27 @@ +{ + "language": { + "not_found": "Language {locale} not found", + 
"already_exists": "Language {locale} already exists", + "add_instructions": "Language {locale} validated successfully. Please create translation files in {dir} directory to complete the addition.", + "update_instructions": "Language {locale} update validated successfully. Please update I18N_SUPPORTED_LANGUAGES environment variable to apply configuration changes." + }, + "namespace": { + "not_found": "Namespace {namespace} not found in language {locale}" + }, + "translation": { + "invalid_key_format": "Invalid translation key format: {key}. Should use format: namespace.key.subkey", + "update_instructions": "Translation {locale}/{key} update validated successfully. Please modify the corresponding JSON translation file to apply changes." + }, + "reload": { + "disabled": "Translation hot reload is disabled. Please enable I18N_ENABLE_HOT_RELOAD in configuration.", + "success": "Translations reloaded successfully", + "failed": "Translation reload failed: {error}" + }, + "metrics": { + "reset_success": "Performance metrics reset successfully" + }, + "logs": { + "export_success": "Missing translations exported to: {file}", + "clear_success": "Missing translation logs cleared successfully" + } +} diff --git a/api/app/locales/en/tenant.json b/api/app/locales/en/tenant.json new file mode 100644 index 00000000..8c3b4b02 --- /dev/null +++ b/api/app/locales/en/tenant.json @@ -0,0 +1,63 @@ +{ + "info": { + "get_success": "Tenant information retrieved successfully", + "get_failed": "Failed to retrieve tenant information", + "update_success": "Tenant information updated successfully", + "update_failed": "Failed to update tenant information" + }, + "create": { + "success": "Tenant created successfully", + "failed": "Failed to create tenant" + }, + "delete": { + "success": "Tenant deleted successfully", + "failed": "Failed to delete tenant" + }, + "status": { + "activate_success": "Tenant activated successfully", + "activate_failed": "Failed to activate tenant", + 
"deactivate_success": "Tenant deactivated successfully", + "deactivate_failed": "Failed to deactivate tenant" + }, + "language": { + "get_success": "Tenant language configuration retrieved successfully", + "get_failed": "Failed to retrieve tenant language configuration", + "update_success": "Tenant language configuration updated successfully", + "update_failed": "Failed to update tenant language configuration", + "invalid_language": "Unsupported language code", + "default_not_in_supported": "Default language must be in the supported languages list" + }, + "list": { + "get_success": "Tenant list retrieved successfully", + "get_failed": "Failed to retrieve tenant list" + }, + "users": { + "list_success": "Tenant user list retrieved successfully", + "list_failed": "Failed to retrieve tenant user list", + "assign_success": "User assigned to tenant successfully", + "assign_failed": "Failed to assign user to tenant", + "remove_success": "User removed from tenant successfully", + "remove_failed": "Failed to remove user from tenant" + }, + "statistics": { + "get_success": "Tenant statistics retrieved successfully", + "get_failed": "Failed to retrieve tenant statistics" + }, + "validation": { + "name_required": "Tenant name is required", + "name_invalid": "Invalid tenant name format", + "name_too_long": "Tenant name cannot exceed {max} characters", + "description_too_long": "Tenant description cannot exceed {max} characters", + "language_code_invalid": "Invalid language code format", + "supported_languages_empty": "Supported languages list cannot be empty" + }, + "errors": { + "not_found": "Tenant not found", + "already_exists": "Tenant name already exists", + "permission_denied": "Permission denied to access this tenant", + "has_users": "Cannot delete tenant, associated users exist", + "has_workspaces": "Cannot delete tenant, associated workspaces exist", + "already_active": "Tenant is already active", + "already_inactive": "Tenant is already inactive" + } +} diff --git 
a/api/app/locales/en/users.json b/api/app/locales/en/users.json new file mode 100644 index 00000000..efd5d034 --- /dev/null +++ b/api/app/locales/en/users.json @@ -0,0 +1,72 @@ +{ + "info": { + "get_success": "User information retrieved successfully", + "get_failed": "Failed to retrieve user information", + "update_success": "User information updated successfully", + "update_failed": "Failed to update user information" + }, + "create": { + "success": "User created successfully", + "failed": "Failed to create user", + "superuser_success": "Superuser created successfully", + "superuser_failed": "Failed to create superuser" + }, + "delete": { + "success": "User deleted successfully", + "failed": "Failed to delete user", + "deactivate_success": "User deactivated successfully", + "deactivate_failed": "Failed to deactivate user" + }, + "activate": { + "success": "User activated successfully", + "failed": "Failed to activate user" + }, + "language": { + "get_success": "Language preference retrieved successfully", + "get_failed": "Failed to retrieve language preference", + "update_success": "Language preference updated successfully", + "update_failed": "Failed to update language preference", + "invalid_language": "Unsupported language code", + "current": "Current language preference" + }, + "email": { + "change_success": "Email changed successfully", + "change_failed": "Failed to change email", + "code_sent": "Verification code has been sent to your email", + "code_send_failed": "Failed to send verification code", + "code_invalid": "Invalid or expired verification code", + "already_exists": "Email already in use" + }, + "list": { + "get_success": "User list retrieved successfully", + "get_failed": "Failed to retrieve user list", + "superusers_success": "Tenant superuser list retrieved successfully", + "superusers_failed": "Failed to retrieve tenant superuser list" + }, + "validation": { + "username_required": "Username is required", + "username_invalid": "Invalid username 
format", + "username_too_long": "Username cannot exceed {max} characters", + "email_required": "Email is required", + "email_invalid": "Invalid email format", + "password_required": "Password is required", + "password_too_short": "Password must be at least {min} characters", + "password_too_long": "Password cannot exceed {max} characters", + "old_password_required": "Old password is required", + "new_password_required": "New password is required", + "verification_code_required": "Verification code is required", + "verification_code_invalid": "Invalid verification code format" + }, + "errors": { + "not_found": "User not found", + "already_exists": "User already exists", + "permission_denied": "Permission denied to access this user", + "cannot_delete_self": "Cannot delete yourself", + "cannot_deactivate_self": "Cannot deactivate yourself", + "already_deactivated": "User is already deactivated", + "already_activated": "User is already activated", + "password_verification_failed": "Password verification failed", + "old_password_incorrect": "Old password is incorrect", + "same_as_old_password": "New password cannot be the same as old password" + } +} diff --git a/api/app/locales/en/workspace.json b/api/app/locales/en/workspace.json new file mode 100644 index 00000000..cca29698 --- /dev/null +++ b/api/app/locales/en/workspace.json @@ -0,0 +1,44 @@ +{ + "list_retrieved": "Workspace list retrieved successfully", + "created": "Workspace created successfully", + "updated": "Workspace updated successfully", + "deleted": "Workspace deleted successfully", + "switched": "Workspace switched successfully", + "not_found": "Workspace not found or access denied", + "already_exists": "Workspace already exists", + "permission_denied": "No permission to access this workspace", + "name_required": "Workspace name is required", + "invalid_name": "Invalid workspace name format", + "members": { + "list_retrieved": "Workspace members list retrieved successfully", + "role_updated": "Member 
role updated successfully", + "deleted": "Member deleted successfully", + "not_found": "Member not found", + "cannot_remove_self": "Cannot remove yourself", + "cannot_remove_last_manager": "Cannot remove the last manager", + "already_member": "User is already a workspace member" + }, + "invites": { + "created": "Invite created successfully", + "list_retrieved": "Invite list retrieved successfully", + "validated": "Invite validated successfully", + "revoked": "Invite revoked successfully", + "accepted": "Invite accepted", + "not_found": "Invite not found", + "expired": "Invite has expired", + "already_used": "Invite has already been used", + "invalid_token": "Invalid invite token", + "email_required": "Email address is required", + "invalid_email": "Invalid email address format" + }, + "storage": { + "type_retrieved": "Storage type retrieved successfully", + "type_updated": "Storage type updated successfully", + "invalid_type": "Invalid storage type" + }, + "models": { + "config_retrieved": "Model configuration retrieved successfully", + "config_updated": "Model configuration updated successfully", + "invalid_config": "Invalid model configuration" + } +} diff --git a/api/app/locales/zh/README.md b/api/app/locales/zh/README.md new file mode 100644 index 00000000..edaa0fb4 --- /dev/null +++ b/api/app/locales/zh/README.md @@ -0,0 +1,28 @@ +# 中文翻译文件 + +此目录包含中文(简体)的翻译文件。 + +## 文件结构 + +- `common.json` - 通用翻译(成功消息、操作、验证) +- `auth.json` - 认证模块翻译 +- `workspace.json` - 工作空间模块翻译 +- `tenant.json` - 租户模块翻译 +- `errors.json` - 错误消息翻译 +- `enums.json` - 枚举值翻译 +- `users.json` - 用户模块翻译 +- `i18n.json` - 国际化管理模块翻译 + +## 翻译文件格式 + +所有翻译文件使用 JSON 格式,支持嵌套结构。 + +示例: +```json +{ + "success": { + "created": "创建成功", + "updated": "更新成功" + } +} +``` diff --git a/api/app/locales/zh/auth.json b/api/app/locales/zh/auth.json new file mode 100644 index 00000000..283d2ffb --- /dev/null +++ b/api/app/locales/zh/auth.json @@ -0,0 +1,55 @@ +{ + "login": { + "success": "登录成功", + "failed": "登录失败", + "invalid_credentials": "用户名或密码错误", + "account_locked": 
"账户已被锁定", + "account_disabled": "账户已被禁用" + }, + "logout": { + "success": "登出成功", + "failed": "登出失败" + }, + "token": { + "refresh_success": "token刷新成功", + "invalid": "无效的token", + "expired": "token已过期", + "blacklisted": "token已失效", + "invalid_refresh_token": "无效的refresh token", + "refresh_token_blacklisted": "Refresh token已失效" + }, + "registration": { + "success": "注册成功", + "failed": "注册失败", + "email_exists": "邮箱已被使用", + "username_exists": "用户名已被使用" + }, + "password": { + "reset_success": "密码重置成功", + "reset_failed": "密码重置失败", + "change_success": "密码修改成功", + "change_failed": "密码修改失败", + "incorrect": "密码错误", + "too_weak": "密码强度不够", + "mismatch": "两次输入的密码不一致" + }, + "invite": { + "invalid": "邀请码无效或已过期", + "email_mismatch": "邀请邮箱与登录邮箱不匹配", + "accept_success": "接受邀请成功", + "accept_failed": "接受邀请失败", + "password_verification_failed": "接受邀请失败,密码验证错误", + "bind_workspace_success": "绑定工作空间成功", + "bind_workspace_failed": "绑定工作空间失败" + }, + "user": { + "not_found": "用户不存在", + "already_exists": "用户已存在", + "created_with_invite": "用户创建成功并已加入工作空间" + }, + "session": { + "expired": "会话已过期,请重新登录", + "invalid": "无效的会话", + "single_session_enabled": "单点登录已启用,其他设备的登录将被注销" + } +} diff --git a/api/app/locales/zh/common.json b/api/app/locales/zh/common.json new file mode 100644 index 00000000..b3c62adc --- /dev/null +++ b/api/app/locales/zh/common.json @@ -0,0 +1,132 @@ +{ + "success": { + "created": "创建成功", + "updated": "更新成功", + "deleted": "删除成功", + "retrieved": "获取成功", + "saved": "保存成功", + "uploaded": "上传成功", + "downloaded": "下载成功", + "sent": "发送成功", + "completed": "完成", + "confirmed": "已确认", + "cancelled": "已取消", + "archived": "已归档", + "restored": "已恢复" + }, + "actions": { + "create": "创建", + "update": "更新", + "delete": "删除", + "view": "查看", + "edit": "编辑", + "save": "保存", + "cancel": "取消", + "confirm": "确认", + "submit": "提交", + "upload": "上传", + "download": "下载", + "send": "发送", + "search": "搜索", + "filter": "筛选", + "sort": "排序", + "export": "导出", + "import": "导入", + "refresh": "刷新", + 
"reset": "重置", + "back": "返回", + "next": "下一步", + "previous": "上一步", + "finish": "完成", + "close": "关闭", + "open": "打开", + "archive": "归档", + "restore": "恢复", + "duplicate": "复制", + "share": "分享", + "invite": "邀请", + "remove": "移除", + "add": "添加", + "select": "选择", + "clear": "清除" + }, + "validation": { + "required": "{field}不能为空", + "invalid_format": "{field}格式不正确", + "too_long": "{field}长度不能超过{max}个字符", + "too_short": "{field}长度不能少于{min}个字符", + "invalid_email": "邮箱格式不正确", + "invalid_url": "URL格式不正确", + "invalid_phone": "手机号格式不正确", + "invalid_date": "日期格式不正确", + "invalid_number": "必须是有效的数字", + "out_of_range": "{field}必须在{min}和{max}之间", + "already_exists": "{field}已存在", + "not_found": "{field}不存在", + "invalid_value": "{field}的值无效", + "password_mismatch": "两次输入的密码不一致", + "weak_password": "密码强度不够,请使用更复杂的密码", + "invalid_credentials": "用户名或密码错误", + "unauthorized": "未授权访问", + "forbidden": "没有权限执行此操作", + "expired": "{field}已过期", + "invalid_token": "无效的令牌", + "file_too_large": "文件大小不能超过{max}", + "invalid_file_type": "不支持的文件类型", + "duplicate": "重复的{field}" + }, + "status": { + "active": "活跃", + "inactive": "未激活", + "pending": "待处理", + "processing": "处理中", + "completed": "已完成", + "failed": "失败", + "cancelled": "已取消", + "archived": "已归档", + "deleted": "已删除", + "draft": "草稿", + "published": "已发布", + "suspended": "已暂停", + "expired": "已过期" + }, + "messages": { + "loading": "加载中...", + "saving": "保存中...", + "processing": "处理中...", + "uploading": "上传中...", + "downloading": "下载中...", + "no_data": "暂无数据", + "no_results": "没有找到结果", + "confirm_delete": "确定要删除吗?此操作不可恢复。", + "confirm_action": "确定要执行此操作吗?", + "operation_success": "操作成功", + "operation_failed": "操作失败", + "please_wait": "请稍候...", + "try_again": "请重试", + "contact_support": "如果问题持续,请联系技术支持" + }, + "pagination": { + "page": "第{page}页", + "of": "共{total}页", + "items": "共{total}条", + "per_page": "每页{count}条", + "showing": "显示第{from}到第{to}条,共{total}条", + "first": "首页", + "last": "末页", + "next": "下一页", + "previous": "上一页" + }, + 
"time": { + "just_now": "刚刚", + "minutes_ago": "{count}分钟前", + "hours_ago": "{count}小时前", + "days_ago": "{count}天前", + "weeks_ago": "{count}周前", + "months_ago": "{count}个月前", + "years_ago": "{count}年前", + "today": "今天", + "yesterday": "昨天", + "tomorrow": "明天" + } +} diff --git a/api/app/locales/zh/enums.json b/api/app/locales/zh/enums.json new file mode 100644 index 00000000..9a241817 --- /dev/null +++ b/api/app/locales/zh/enums.json @@ -0,0 +1,132 @@ +{ + "workspace_role": { + "owner": "所有者", + "manager": "管理员", + "member": "成员", + "guest": "访客" + }, + "workspace_status": { + "active": "活跃", + "inactive": "未激活", + "archived": "已归档", + "suspended": "已暂停", + "deleted": "已删除" + }, + "invite_status": { + "pending": "待处理", + "accepted": "已接受", + "rejected": "已拒绝", + "revoked": "已撤销", + "expired": "已过期" + }, + "user_status": { + "active": "活跃", + "inactive": "未激活", + "suspended": "已暂停", + "deleted": "已删除", + "pending": "待激活" + }, + "tenant_status": { + "active": "活跃", + "inactive": "未激活", + "suspended": "已暂停", + "expired": "已过期", + "trial": "试用中" + }, + "file_status": { + "uploading": "上传中", + "processing": "处理中", + "completed": "已完成", + "failed": "失败", + "deleted": "已删除" + }, + "task_status": { + "pending": "待处理", + "running": "运行中", + "completed": "已完成", + "failed": "失败", + "cancelled": "已取消", + "paused": "已暂停" + }, + "priority": { + "low": "低", + "medium": "中", + "high": "高", + "urgent": "紧急" + }, + "visibility": { + "public": "公开", + "private": "私有", + "internal": "内部", + "shared": "共享" + }, + "permission": { + "read": "读取", + "write": "写入", + "delete": "删除", + "admin": "管理", + "owner": "所有者" + }, + "notification_type": { + "info": "信息", + "warning": "警告", + "error": "错误", + "success": "成功" + }, + "language": { + "zh": "中文(简体)", + "en": "English", + "ja": "日本語", + "ko": "한국어", + "fr": "Français", + "de": "Deutsch", + "es": "Español" + }, + "timezone": { + "utc": "UTC", + "asia_shanghai": "亚洲/上海", + "asia_tokyo": "亚洲/东京", + "america_new_york": "美洲/纽约", + 
"europe_london": "欧洲/伦敦" + }, + "date_format": { + "short": "短日期", + "medium": "中等日期", + "long": "长日期", + "full": "完整日期" + }, + "sort_order": { + "asc": "升序", + "desc": "降序" + }, + "filter_operator": { + "equals": "等于", + "not_equals": "不等于", + "contains": "包含", + "not_contains": "不包含", + "starts_with": "开始于", + "ends_with": "结束于", + "greater_than": "大于", + "less_than": "小于", + "greater_or_equal": "大于等于", + "less_or_equal": "小于等于", + "in": "在列表中", + "not_in": "不在列表中", + "is_null": "为空", + "is_not_null": "不为空" + }, + "log_level": { + "debug": "调试", + "info": "信息", + "warning": "警告", + "error": "错误", + "critical": "严重" + }, + "api_method": { + "get": "GET", + "post": "POST", + "put": "PUT", + "patch": "PATCH", + "delete": "DELETE" + } +} diff --git a/api/app/locales/zh/errors.json b/api/app/locales/zh/errors.json new file mode 100644 index 00000000..eafadad4 --- /dev/null +++ b/api/app/locales/zh/errors.json @@ -0,0 +1,138 @@ +{ + "common": { + "internal_error": "服务器内部错误", + "network_error": "网络连接错误", + "timeout": "请求超时", + "service_unavailable": "服务暂时不可用", + "bad_request": "请求参数错误", + "unauthorized": "未授权访问", + "forbidden": "没有权限访问", + "not_found": "请求的资源不存在", + "method_not_allowed": "不支持的请求方法", + "conflict": "资源冲突", + "too_many_requests": "请求过于频繁,请稍后再试", + "validation_failed": "数据验证失败", + "database_error": "数据库操作失败", + "file_operation_error": "文件操作失败" + }, + "auth": { + "invalid_credentials": "用户名或密码错误", + "token_expired": "登录已过期,请重新登录", + "token_invalid": "无效的登录令牌", + "token_missing": "缺少登录令牌", + "unauthorized": "未授权访问", + "forbidden": "没有权限执行此操作", + "account_locked": "账户已被锁定", + "account_disabled": "账户已被禁用", + "account_not_verified": "账户未验证", + "password_incorrect": "密码错误", + "password_too_weak": "密码强度不够", + "password_expired": "密码已过期,请修改密码", + "email_not_verified": "邮箱未验证", + "phone_not_verified": "手机号未验证", + "verification_code_invalid": "验证码无效", + "verification_code_expired": "验证码已过期", + "login_failed": "登录失败", + "logout_failed": "登出失败", + "session_expired": 
"会话已过期", + "already_logged_in": "已经登录", + "not_logged_in": "未登录" + }, + "user": { + "not_found": "用户不存在", + "already_exists": "用户已存在", + "email_already_exists": "邮箱已被使用", + "phone_already_exists": "手机号已被使用", + "username_already_exists": "用户名已被使用", + "invalid_email": "邮箱格式不正确", + "invalid_phone": "手机号格式不正确", + "invalid_username": "用户名格式不正确", + "create_failed": "创建用户失败", + "update_failed": "更新用户失败", + "delete_failed": "删除用户失败", + "cannot_delete_self": "不能删除自己", + "cannot_update_self_role": "不能修改自己的角色", + "profile_update_failed": "更新个人资料失败", + "avatar_upload_failed": "上传头像失败", + "password_change_failed": "修改密码失败", + "old_password_incorrect": "原密码错误" + }, + "workspace": { + "not_found": "工作空间不存在", + "already_exists": "工作空间已存在", + "name_required": "工作空间名称不能为空", + "name_too_long": "工作空间名称过长", + "create_failed": "创建工作空间失败", + "update_failed": "更新工作空间失败", + "delete_failed": "删除工作空间失败", + "permission_denied": "没有权限访问此工作空间", + "not_member": "不是工作空间成员", + "already_member": "已经是工作空间成员", + "member_limit_reached": "成员数量已达上限", + "cannot_leave_last_manager": "不能离开,您是最后一个管理员", + "cannot_remove_last_manager": "不能移除最后一个管理员", + "cannot_remove_self": "不能移除自己", + "invite_not_found": "邀请不存在", + "invite_expired": "邀请已过期", + "invite_already_accepted": "邀请已被接受", + "invite_already_revoked": "邀请已被撤销", + "invite_send_failed": "发送邀请失败", + "archived": "工作空间已归档", + "suspended": "工作空间已暂停" + }, + "tenant": { + "not_found": "租户不存在", + "already_exists": "租户已存在", + "create_failed": "创建租户失败", + "update_failed": "更新租户失败", + "delete_failed": "删除租户失败", + "suspended": "租户已暂停", + "expired": "租户已过期", + "license_invalid": "许可证无效", + "license_expired": "许可证已过期", + "quota_exceeded": "配额已超限" + }, + "file": { + "not_found": "文件不存在", + "upload_failed": "文件上传失败", + "download_failed": "文件下载失败", + "delete_failed": "文件删除失败", + "too_large": "文件大小超过限制", + "invalid_type": "不支持的文件类型", + "invalid_format": "文件格式不正确", + "corrupted": "文件已损坏", + "storage_full": "存储空间已满", + "access_denied": "没有权限访问此文件" + }, + "api": { + 
"rate_limit_exceeded": "API调用频率超限", + "quota_exceeded": "API调用配额已用完", + "invalid_api_key": "无效的API密钥", + "api_key_expired": "API密钥已过期", + "api_key_revoked": "API密钥已被撤销", + "endpoint_not_found": "API端点不存在", + "method_not_allowed": "不支持的请求方法", + "invalid_request": "无效的请求", + "missing_parameter": "缺少必需参数:{param}", + "invalid_parameter": "参数无效:{param}" + }, + "database": { + "connection_failed": "数据库连接失败", + "query_failed": "数据库查询失败", + "transaction_failed": "数据库事务失败", + "constraint_violation": "数据约束冲突", + "duplicate_key": "数据重复", + "foreign_key_violation": "外键约束冲突", + "deadlock": "数据库死锁" + }, + "validation": { + "invalid_input": "输入数据无效", + "missing_field": "缺少必需字段:{field}", + "invalid_field": "字段无效:{field}", + "field_too_long": "字段过长:{field}", + "field_too_short": "字段过短:{field}", + "invalid_format": "格式不正确:{field}", + "invalid_value": "值无效:{field}", + "out_of_range": "值超出范围:{field}" + } +} diff --git a/api/app/locales/zh/i18n.json b/api/app/locales/zh/i18n.json new file mode 100644 index 00000000..a072f332 --- /dev/null +++ b/api/app/locales/zh/i18n.json @@ -0,0 +1,27 @@ +{ + "language": { + "not_found": "语言 {locale} 不存在", + "already_exists": "语言 {locale} 已存在", + "add_instructions": "语言 {locale} 验证成功。请在 {dir} 目录下创建翻译文件以完成添加。", + "update_instructions": "语言 {locale} 更新验证成功。请更新环境变量 I18N_SUPPORTED_LANGUAGES 以应用配置更改。" + }, + "namespace": { + "not_found": "命名空间 {namespace} 在语言 {locale} 中不存在" + }, + "translation": { + "invalid_key_format": "翻译键格式无效: {key}。应使用格式: namespace.key.subkey", + "update_instructions": "翻译 {locale}/{key} 更新验证成功。请修改对应的 JSON 翻译文件以应用更改。" + }, + "reload": { + "disabled": "翻译热重载功能已禁用。请在配置中启用 I18N_ENABLE_HOT_RELOAD。", + "success": "翻译重载成功", + "failed": "翻译重载失败: {error}" + }, + "metrics": { + "reset_success": "性能指标已重置" + }, + "logs": { + "export_success": "缺失翻译已导出到: {file}", + "clear_success": "缺失翻译日志已清除" + } +} diff --git a/api/app/locales/zh/tenant.json b/api/app/locales/zh/tenant.json new file mode 100644 index 00000000..a8bdc124 --- /dev/null +++ 
b/api/app/locales/zh/tenant.json @@ -0,0 +1,63 @@ +{ + "info": { + "get_success": "租户信息获取成功", + "get_failed": "租户信息获取失败", + "update_success": "租户信息更新成功", + "update_failed": "租户信息更新失败" + }, + "create": { + "success": "租户创建成功", + "failed": "租户创建失败" + }, + "delete": { + "success": "租户删除成功", + "failed": "租户删除失败" + }, + "status": { + "activate_success": "租户启用成功", + "activate_failed": "租户启用失败", + "deactivate_success": "租户禁用成功", + "deactivate_failed": "租户禁用失败" + }, + "language": { + "get_success": "租户语言配置获取成功", + "get_failed": "租户语言配置获取失败", + "update_success": "租户语言配置更新成功", + "update_failed": "租户语言配置更新失败", + "invalid_language": "不支持的语言代码", + "default_not_in_supported": "默认语言必须在支持的语言列表中" + }, + "list": { + "get_success": "租户列表获取成功", + "get_failed": "租户列表获取失败" + }, + "users": { + "list_success": "租户用户列表获取成功", + "list_failed": "租户用户列表获取失败", + "assign_success": "用户分配到租户成功", + "assign_failed": "用户分配到租户失败", + "remove_success": "用户从租户移除成功", + "remove_failed": "用户从租户移除失败" + }, + "statistics": { + "get_success": "租户统计信息获取成功", + "get_failed": "租户统计信息获取失败" + }, + "validation": { + "name_required": "租户名称不能为空", + "name_invalid": "租户名称格式不正确", + "name_too_long": "租户名称长度不能超过{max}个字符", + "description_too_long": "租户描述长度不能超过{max}个字符", + "language_code_invalid": "语言代码格式不正确", + "supported_languages_empty": "支持的语言列表不能为空" + }, + "errors": { + "not_found": "租户不存在", + "already_exists": "租户名称已存在", + "permission_denied": "没有权限访问此租户", + "has_users": "无法删除租户,存在关联的用户", + "has_workspaces": "无法删除租户,存在关联的工作空间", + "already_active": "租户已处于激活状态", + "already_inactive": "租户已处于禁用状态" + } +} diff --git a/api/app/locales/zh/users.json b/api/app/locales/zh/users.json new file mode 100644 index 00000000..a446ed8d --- /dev/null +++ b/api/app/locales/zh/users.json @@ -0,0 +1,72 @@ +{ + "info": { + "get_success": "用户信息获取成功", + "get_failed": "用户信息获取失败", + "update_success": "用户信息更新成功", + "update_failed": "用户信息更新失败" + }, + "create": { + "success": "用户创建成功", + "failed": "用户创建失败", + "superuser_success": "超级管理员创建成功", + 
"superuser_failed": "超级管理员创建失败" + }, + "delete": { + "success": "用户删除成功", + "failed": "用户删除失败", + "deactivate_success": "用户停用成功", + "deactivate_failed": "用户停用失败" + }, + "activate": { + "success": "用户激活成功", + "failed": "用户激活失败" + }, + "language": { + "get_success": "语言偏好获取成功", + "get_failed": "语言偏好获取失败", + "update_success": "语言偏好更新成功", + "update_failed": "语言偏好更新失败", + "invalid_language": "不支持的语言代码", + "current": "当前语言偏好" + }, + "email": { + "change_success": "邮箱修改成功", + "change_failed": "邮箱修改失败", + "code_sent": "验证码已发送到您的邮箱,请查收", + "code_send_failed": "验证码发送失败", + "code_invalid": "验证码无效或已过期", + "already_exists": "该邮箱已被使用" + }, + "list": { + "get_success": "用户列表获取成功", + "get_failed": "用户列表获取失败", + "superusers_success": "租户超管列表获取成功", + "superusers_failed": "租户超管列表获取失败" + }, + "validation": { + "username_required": "用户名不能为空", + "username_invalid": "用户名格式不正确", + "username_too_long": "用户名长度不能超过{max}个字符", + "email_required": "邮箱不能为空", + "email_invalid": "邮箱格式不正确", + "password_required": "密码不能为空", + "password_too_short": "密码长度不能少于{min}个字符", + "password_too_long": "密码长度不能超过{max}个字符", + "old_password_required": "旧密码不能为空", + "new_password_required": "新密码不能为空", + "verification_code_required": "验证码不能为空", + "verification_code_invalid": "验证码格式不正确" + }, + "errors": { + "not_found": "用户不存在", + "already_exists": "用户已存在", + "permission_denied": "没有权限访问此用户", + "cannot_delete_self": "不能删除自己", + "cannot_deactivate_self": "不能停用自己", + "already_deactivated": "用户已被停用", + "already_activated": "用户已处于激活状态", + "password_verification_failed": "密码验证失败", + "old_password_incorrect": "旧密码不正确", + "same_as_old_password": "新密码不能与旧密码相同" + } +} diff --git a/api/app/locales/zh/workspace.json b/api/app/locales/zh/workspace.json new file mode 100644 index 00000000..e7dba7dc --- /dev/null +++ b/api/app/locales/zh/workspace.json @@ -0,0 +1,44 @@ +{ + "list_retrieved": "工作空间列表获取成功", + "created": "工作空间创建成功", + "updated": "工作空间更新成功", + "deleted": "工作空间删除成功", + "switched": "工作空间切换成功", + "not_found": 
"工作空间不存在或无权访问", + "already_exists": "工作空间已存在", + "permission_denied": "没有权限访问此工作空间", + "name_required": "工作空间名称不能为空", + "invalid_name": "工作空间名称格式不正确", + "members": { + "list_retrieved": "工作空间成员列表获取成功", + "role_updated": "成员角色更新成功", + "deleted": "成员删除成功", + "not_found": "成员不存在", + "cannot_remove_self": "不能删除自己", + "cannot_remove_last_manager": "不能删除最后一个管理员", + "already_member": "用户已经是工作空间成员" + }, + "invites": { + "created": "邀请创建成功", + "list_retrieved": "邀请列表获取成功", + "validated": "邀请验证成功", + "revoked": "邀请撤销成功", + "accepted": "邀请已接受", + "not_found": "邀请不存在", + "expired": "邀请已过期", + "already_used": "邀请已被使用", + "invalid_token": "无效的邀请令牌", + "email_required": "邮箱地址不能为空", + "invalid_email": "邮箱地址格式不正确" + }, + "storage": { + "type_retrieved": "存储类型获取成功", + "type_updated": "存储类型更新成功", + "invalid_type": "无效的存储类型" + }, + "models": { + "config_retrieved": "模型配置获取成功", + "config_updated": "模型配置更新成功", + "invalid_config": "无效的模型配置" + } +} diff --git a/api/app/main.py b/api/app/main.py index af5ed796..c6256e3c 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -92,6 +92,10 @@ app.add_middleware( allow_headers=["*"], ) +# Add i18n language detection middleware +from app.i18n.middleware import LanguageMiddleware +app.add_middleware(LanguageMiddleware) + logger.info("FastAPI应用程序启动") @@ -129,6 +133,11 @@ from app.core.exceptions import ( from app.core.sensitive_filter import SensitiveDataFilter import traceback +# Import i18n exception support +from app.i18n.exceptions import I18nException +from app.i18n.service import get_translation_service +from pydantic import ValidationError as PydanticValidationError + # 处理验证异常 @app.exception_handler(ValidationException) @@ -156,6 +165,131 @@ async def validation_exception_handler(request: Request, exc: ValidationExceptio ) +# 处理 i18n 异常(国际化异常) +@app.exception_handler(I18nException) +async def i18n_exception_handler(request: Request, exc: I18nException): + """ + 处理国际化异常 + + I18nException 已经自动翻译了错误消息,直接返回即可 + """ + # 获取当前语言 + language = 
getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE) + + # 获取异常详情(已经包含翻译后的消息) + detail = exc.detail + + # 过滤敏感信息 + if isinstance(detail, dict): + filtered_message = SensitiveDataFilter.filter_string(detail.get("message", "")) + filtered_detail = { + **detail, + "message": filtered_message + } + else: + filtered_detail = SensitiveDataFilter.filter_string(str(detail)) + + logger.warning( + f"I18n exception: {exc.error_key}", + extra={ + "path": request.url.path, + "method": request.method, + "error_code": exc.error_code, + "error_key": exc.error_key, + "language": language, + "status_code": exc.status_code, + "params": exc.params + } + ) + + return JSONResponse( + status_code=exc.status_code, + content={ + "success": False, + **filtered_detail + }, + headers=exc.headers + ) + + +# 处理 Pydantic 验证错误(国际化支持) +@app.exception_handler(PydanticValidationError) +async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError): + """ + 处理 Pydantic 验证错误,支持国际化 + """ + # 获取当前语言 + language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE) + + # 获取翻译服务 + translation_service = get_translation_service() + + # 翻译验证错误消息 + errors = [] + for error in exc.errors(): + field = ".".join(str(loc) for loc in error["loc"]) + error_type = error["type"] + + # 尝试翻译错误消息 + if error_type == "value_error.missing": + message = translation_service.translate( + "errors.validation.missing_field", + language, + field=field + ) + elif error_type == "value_error.any_str.max_length": + message = translation_service.translate( + "errors.validation.field_too_long", + language, + field=field + ) + elif error_type == "value_error.any_str.min_length": + message = translation_service.translate( + "errors.validation.field_too_short", + language, + field=field + ) + else: + # 使用通用验证错误消息 + message = translation_service.translate( + "errors.validation.invalid_field", + language, + field=field + ) + + errors.append({ + "field": field, + "message": message, 
+ "type": error_type + }) + + # 翻译主错误消息 + main_message = translation_service.translate( + "errors.common.validation_failed", + language + ) + + logger.warning( + f"Pydantic validation error: {len(errors)} errors", + extra={ + "path": request.url.path, + "method": request.method, + "language": language, + "errors": errors + } + ) + + return JSONResponse( + status_code=422, + content={ + "success": False, + "error_code": "VALIDATION_FAILED", + "message": main_message, + "errors": errors + } + ) + + # 处理资源不存在异常 @app.exception_handler(ResourceNotFoundException) async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException): @@ -354,31 +488,66 @@ async def business_exception_handler(request: Request, exc: BusinessException): ) -# 统一异常处理:将HTTPException转换为统一响应结构 +# 统一异常处理:将HTTPException转换为统一响应结构(支持国际化) @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): - """处理HTTP异常""" - # 过滤敏感信息 - filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail)) - + """处理HTTP异常,支持国际化""" + # 获取当前语言 + language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE) + + # 获取翻译服务 + translation_service = get_translation_service() + + # 尝试翻译标准HTTP错误 + error_key_map = { + 400: "errors.common.bad_request", + 401: "errors.common.unauthorized", + 403: "errors.common.forbidden", + 404: "errors.common.not_found", + 405: "errors.common.method_not_allowed", + 409: "errors.common.conflict", + 422: "errors.common.validation_failed", + 429: "errors.common.too_many_requests", + 500: "errors.common.internal_error", + 503: "errors.common.service_unavailable", + } + + # 如果有对应的翻译键,使用翻译 + if exc.status_code in error_key_map: + translated_message = translation_service.translate( + error_key_map[exc.status_code], + language + ) + else: + # 否则过滤原始消息 + translated_message = SensitiveDataFilter.filter_string(str(exc.detail)) + logger.warning( - f"HTTP exception: {filtered_detail}", + f"HTTP exception: 
{translated_message}", extra={ "path": request.url.path, "method": request.method, - "status_code": exc.status_code + "status_code": exc.status_code, + "language": language } ) + return JSONResponse( status_code=exc.status_code, - content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail) + content=fail(code=exc.status_code, msg=translated_message, error=translated_message) ) -# 捕获未处理的异常,返回统一错误结构 +# 捕获未处理的异常,返回统一错误结构(支持国际化) @app.exception_handler(Exception) async def unhandled_exception_handler(request: Request, exc: Exception): - """处理未捕获的异常""" + """处理未捕获的异常,支持国际化""" + # 获取当前语言 + language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE) + + # 获取翻译服务 + translation_service = get_translation_service() + # 记录完整的堆栈跟踪(日志过滤器会自动过滤敏感信息) logger.error( f"Unhandled exception: {exc}", @@ -386,6 +555,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception): "path": request.url.path, "method": request.method, "exception_type": type(exc).__name__, + "language": language, "traceback": traceback.format_exc() }, exc_info=True @@ -394,7 +564,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception): # 生产环境隐藏详细错误信息 environment = os.getenv("ENVIRONMENT", "development") if environment == "production": - message = "服务器内部错误,请稍后重试" + # 使用翻译的通用错误消息 + message = translation_service.translate( + "errors.common.internal_error", + language + ) else: # 开发环境也要过滤敏感信息 message = SensitiveDataFilter.filter_string(str(exc)) diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py index cafb18d4..9fed7c5d 100644 --- a/api/app/models/memory_perceptual_model.py +++ b/api/app/models/memory_perceptual_model.py @@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB from app.db import Base +from app.schemas import FileType class PerceptualType(IntEnum): @@ -15,6 +16,16 @@ class PerceptualType(IntEnum): TEXT = 3 CONVERSATION = 4 + 
@staticmethod + def trans_from_file_type(file_type: FileType | str): + type_map = { + FileType.IMAGE: PerceptualType.VISION, + FileType.AUDIO: PerceptualType.AUDIO, + FileType.VIDEO: PerceptualType.VISION, + FileType.DOCUMENT: PerceptualType.TEXT + } + return type_map.get(file_type, PerceptualType.TEXT) + class FileStorageService(IntEnum): LOCAL = 1 diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index 54a3e347..044857d2 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -1,7 +1,7 @@ import datetime import uuid -from sqlalchemy import Column, String, DateTime, Boolean -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy import Column, String, DateTime, Boolean, text +from sqlalchemy.dialects.postgresql import UUID, ARRAY from sqlalchemy.orm import relationship from app.db import Base @@ -20,6 +20,10 @@ class Tenants(Base): external_id = Column(String(100), nullable=True, index=True) # 外部企业ID external_source = Column(String(50), nullable=True) # 来源系统 + # 国际化语言配置字段 + default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言 + supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表 + # Relationship to users - one tenant has many users users = relationship("User", back_populates="tenant") diff --git a/api/app/models/user_model.py b/api/app/models/user_model.py index 663bfc71..b6de28ec 100644 --- a/api/app/models/user_model.py +++ b/api/app/models/user_model.py @@ -1,6 +1,6 @@ import datetime import uuid -from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey +from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -22,6 +22,9 @@ class User(Base): external_id = Column(String(100), nullable=True) # 外部用户ID 
external_source = Column(String(50), nullable=True) # 来源系统 + # 用户语言偏好 + preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文 + current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True) # 当前工作空间ID,可为空 # Foreign key to tenant - each user belongs to exactly one tenant diff --git a/api/app/repositories/memory_perceptual_repository.py b/api/app/repositories/memory_perceptual_repository.py index 9fa9536e..9077af03 100644 --- a/api/app/repositories/memory_perceptual_repository.py +++ b/api/app/repositories/memory_perceptual_repository.py @@ -2,7 +2,7 @@ import uuid from datetime import datetime from typing import List, Tuple, Optional -from sqlalchemy import and_, desc +from sqlalchemy import and_, desc, select from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger @@ -127,6 +127,17 @@ class MemoryPerceptualRepository: db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}") raise + def get_by_url( + self, + file_url: str + ) -> list[MemoryPerceptualModel]: + try: + stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url) + return list(self.db.execute(stmt).scalars()) + except Exception: + db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url}") + raise + def get_by_type( self, end_user_id: uuid.UUID, diff --git a/api/app/schemas/i18n_schema.py b/api/app/schemas/i18n_schema.py new file mode 100644 index 00000000..b2ae93c6 --- /dev/null +++ b/api/app/schemas/i18n_schema.py @@ -0,0 +1,73 @@ +""" +I18n Management API Schemas + +This module defines Pydantic schemas for i18n management APIs. 
+""" + +from pydantic import BaseModel, Field +from typing import Dict, List, Optional, Any + + +# ============================================================================ +# Language Management Schemas +# ============================================================================ + +class LanguageInfo(BaseModel): + """Language information""" + code: str = Field(..., description="Language code (e.g., 'zh', 'en')") + name: str = Field(..., description="Language name (e.g., 'Chinese', 'English')") + native_name: str = Field(..., description="Native language name (e.g., '中文', 'English')") + is_enabled: bool = Field(..., description="Whether the language is enabled") + is_default: bool = Field(..., description="Whether this is the default language") + + +class LanguageListResponse(BaseModel): + """Response for language list""" + languages: List[LanguageInfo] = Field(..., description="List of available languages") + + +class LanguageCreateRequest(BaseModel): + """Request to add a new language""" + code: str = Field(..., description="Language code (e.g., 'ja', 'ko')", min_length=2, max_length=10) + name: str = Field(..., description="Language name", min_length=1, max_length=100) + native_name: str = Field(..., description="Native language name", min_length=1, max_length=100) + is_enabled: bool = Field(default=True, description="Whether to enable the language") + + +class LanguageUpdateRequest(BaseModel): + """Request to update language configuration""" + is_enabled: Optional[bool] = Field(None, description="Whether the language is enabled") + is_default: Optional[bool] = Field(None, description="Whether this is the default language") + + +# ============================================================================ +# Translation Management Schemas +# ============================================================================ + +class TranslationResponse(BaseModel): + """Response for translation data""" + translations: Dict[str, Dict[str, Any]] = Field( + ..., + 
description="Translations organized by locale and namespace" + ) + + +class TranslationUpdateRequest(BaseModel): + """Request to update a translation""" + value: str = Field(..., description="New translation value", min_length=1) + description: Optional[str] = Field(None, description="Optional description of the translation") + + +class MissingTranslationsResponse(BaseModel): + """Response for missing translations""" + missing_translations: Dict[str, List[str]] = Field( + ..., + description="Missing translation keys organized by locale" + ) + + +class ReloadResponse(BaseModel): + """Response for translation reload""" + success: bool = Field(..., description="Whether the reload was successful") + reloaded_locales: List[str] = Field(..., description="List of reloaded locales") + total_locales: int = Field(..., description="Total number of available locales") diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 1a5017eb..26a7390b 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -25,5 +25,6 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 + TIME_SCOPE=5 diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py index 7dfefe01..c9b741ef 100644 --- a/api/app/schemas/memory_perceptual_schema.py +++ b/api/app/schemas/memory_perceptual_schema.py @@ -1,5 +1,4 @@ import uuid -from datetime import datetime from typing import Optional from pydantic import BaseModel, Field @@ -85,7 +84,6 @@ class Semantic(BaseModel): class Content(BaseModel): - summary: str keywords: list[str] topic: str domain: str diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 4f3878ce..058f082d 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel): is_official: Optional[bool] = Field(None, 
description="是否官方模型") is_deprecated: Optional[bool] = Field(None, description="是否弃用") search: Optional[str] = Field(None, description="搜索关键词", max_length=255) + +class ModelInfo(BaseModel): + """模型信息Schema""" + model_name: str = Field(..., description="模型名称") + provider: str = Field(..., description="模型提供商") + api_key: str = Field(..., description="API密钥") + api_base: str = Field(..., description="API基础URL") + is_omni: bool = Field(default=False, description="是否为omni模型") + model_type: ModelType = Field(..., description="模型类型") + capability: List[str] = Field(default_factory=list, description="模型能力列表") + diff --git a/api/app/schemas/tenant_schema.py b/api/app/schemas/tenant_schema.py index 6e8bd158..4f49ee88 100644 --- a/api/app/schemas/tenant_schema.py +++ b/api/app/schemas/tenant_schema.py @@ -11,6 +11,8 @@ class TenantBase(BaseModel): name: str = Field(..., description="租户名称", max_length=255) description: Optional[str] = Field(None, description="租户描述", max_length=1000) is_active: bool = Field(True, description="是否激活") + default_language: Optional[str] = Field('zh', description="租户默认语言", max_length=10) + supported_languages: Optional[List[str]] = Field(['zh', 'en'], description="租户支持的语言列表") @field_validator('name') @classmethod @@ -18,6 +20,26 @@ class TenantBase(BaseModel): if not v or not v.strip(): raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED) return v.strip() + + @field_validator('default_language') + @classmethod + def validate_default_language(cls, v): + if v: + # Validate language code format (2-letter code, optionally with region) + import re + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v): + raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED) + return v + + @field_validator('supported_languages') + @classmethod + def validate_supported_languages(cls, v): + if v: + import re + for lang in v: + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang): + raise ValidationException(f'语言代码格式不正确: {lang}', 
code=BizCode.VALIDATION_FAILED) + return v class TenantCreate(TenantBase): @@ -30,6 +52,8 @@ class TenantUpdate(BaseModel): name: Optional[str] = Field(None, description="租户名称", max_length=255) description: Optional[str] = Field(None, description="租户描述", max_length=1000) is_active: Optional[bool] = Field(None, description="是否激活") + default_language: Optional[str] = Field(None, description="租户默认语言", max_length=10) + supported_languages: Optional[List[str]] = Field(None, description="租户支持的语言列表") @field_validator('name') @classmethod @@ -37,6 +61,25 @@ class TenantUpdate(BaseModel): if v is not None and (not v or not v.strip()): raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED) return v.strip() if v else v + + @field_validator('default_language') + @classmethod + def validate_default_language(cls, v): + if v: + import re + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v): + raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED) + return v + + @field_validator('supported_languages') + @classmethod + def validate_supported_languages(cls, v): + if v: + import re + for lang in v: + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang): + raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED) + return v class Tenant(TenantBase): @@ -62,4 +105,29 @@ class TenantList(BaseModel): total: int page: int size: int - pages: int \ No newline at end of file + pages: int + + +class TenantLanguageConfig(BaseModel): + """租户语言配置Schema""" + default_language: str = Field(..., description="租户默认语言", max_length=10) + supported_languages: List[str] = Field(..., description="租户支持的语言列表") + + @field_validator('default_language') + @classmethod + def validate_default_language(cls, v): + import re + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', v): + raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED) + return v + + @field_validator('supported_languages') + @classmethod + def validate_supported_languages(cls, v): + if not v: + 
raise ValidationException('支持的语言列表不能为空', code=BizCode.VALIDATION_FAILED) + import re + for lang in v: + if not re.match(r'^[a-z]{2}(-[A-Z]{2})?$', lang): + raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED) + return v diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py index 7b9e201d..6b880696 100644 --- a/api/app/schemas/user_schema.py +++ b/api/app/schemas/user_schema.py @@ -58,6 +58,16 @@ class VerifyPasswordRequest(BaseModel): password: str = Field(..., description="密码") +class LanguagePreferenceRequest(BaseModel): + """语言偏好设置请求""" + language: str = Field(..., min_length=2, max_length=10, description="语言代码,如 'zh', 'en'") + + +class LanguagePreferenceResponse(BaseModel): + """语言偏好响应""" + language: str = Field(..., description="当前语言偏好") + + class ChangePasswordResponse(BaseModel): """修改密码响应""" message: str @@ -74,6 +84,7 @@ class User(UserBase): current_workspace_id: Optional[uuid.UUID] = None current_workspace_name: Optional[str] = None role: Optional[WorkspaceRole] = None + preferred_language: Optional[str] = "zh" # 用户语言偏好 # 将 datetime 转换为毫秒时间戳 @validator("created_at", pre=True) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index f3cdde2a..9b2b2a77 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List from fastapi import Depends from sqlalchemy.orm import Session -from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.db import get_db -from app.models import MultiAgentConfig, AgentConfig +from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig from app.repositories.tool_repository import ToolRepository from app.schemas 
import DraftRunRequest from app.schemas.app_schema import FileInput +from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ - AgentRunService -from app.services.draft_run_service import create_web_search_tool +from app.services.draft_run_service import AgentRunService from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService -from app.services.tool_service import ToolService from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -126,8 +122,17 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) - processed_files = await multimodal_service.process_files(files) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") # 调用 Agent(支持多模态) @@ -266,8 +271,17 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) - processed_files = await multimodal_service.process_files(files) + model_info = ModelInfo( + model_name=api_key_obj.model_name, + provider=api_key_obj.provider, + api_key=api_key_obj.api_key, + api_base=api_key_obj.api_base, + 
capability=api_key_obj.capability, + is_omni=api_key_obj.is_omni, + model_type=ModelType.LLM + ) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件") # 流式调用 Agent(支持多模态) diff --git a/api/app/services/audio_transcription_service.py b/api/app/services/audio_transcription_service.py index 11d13f38..8b94bbe2 100644 --- a/api/app/services/audio_transcription_service.py +++ b/api/app/services/audio_transcription_service.py @@ -75,7 +75,7 @@ class AudioTranscriptionService: try: # 下载音频文件 async with httpx.AsyncClient(timeout=60.0) as client: - audio_response = await client.get(audio_url) + audio_response = await client.get(audio_url, follow_redirects=True) audio_response.raise_for_status() audio_data = audio_response.content diff --git a/api/app/services/auth_service.py b/api/app/services/auth_service.py index 03e1ebc0..436a5c96 100644 --- a/api/app/services/auth_service.py +++ b/api/app/services/auth_service.py @@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode from app.core.logging_config import get_auth_logger + from app.i18n.service import t logger = get_auth_logger() @@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User: user = user_repository.get_user_by_email(db, email=email) if not user: logger.warning(f"用户不存在: {email}") - raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND) # 检查用户状态 if not user.is_active: logger.warning(f"用户未激活: {email}") - raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND) # 验证密码 if not verify_password(password, user.hashed_password): 
logger.warning(f"密码错误: {email}") - raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR) + raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR) logger.info(f"用户认证成功: {email}") return user @@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict: Raises: BusinessException: token 无效 """ + from app.i18n.service import t + try: payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM]) return { @@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict: "share_token": payload["share_token"] } except jwt.InvalidTokenError: - raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN) \ No newline at end of file + raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN) \ No newline at end of file diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index d7914db5..b3b136a1 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig +from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput +from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service from app.services.conversation_service import ConversationService @@ -501,9 +502,18 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + 
is_omni=api_key_config["is_omni"], + model_type=ModelType.LLM + ) provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) - processed_files = await multimodal_service.process_files(files) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 知识库检索 @@ -704,9 +714,18 @@ class AgentRunService: processed_files = None if files: # 获取 provider 信息 + model_info = ModelInfo( + model_name=api_key_config["model_name"], + provider=api_key_config["provider"], + api_key=api_key_config["api_key"], + api_base=api_key_config["api_base"], + capability=api_key_config["capability"], + is_omni=api_key_config["is_omni"], + model_type=ModelType.LLM + ) provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) - processed_files = await multimodal_service.process_files(files) + multimodal_service = MultimodalService(self.db, model_info) + processed_files = await multimodal_service.process_files(user_id, files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") # 7. 
知识库检索 @@ -841,7 +860,8 @@ class AgentRunService: "api_key": api_key.api_key, "api_base": api_key.api_base, "api_key_id": api_key.id, - "is_omni": api_key.is_omni + "is_omni": api_key.is_omni, + "capability": api_key.capability } async def _ensure_conversation( diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index a20b968a..1e1d9e45 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -274,7 +274,7 @@ class MemoryAgentService: Args: end_user_id: Group identifier (also used as end_user_id) - message: Message to write + messages: Message to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index b9d96a0b..53d935fe 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -1,19 +1,27 @@ +import os import uuid from typing import Dict, Any, Optional +from urllib.parse import urlparse, unquote +import json_repair +from jinja2 import Template from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig from app.models.memory_perceptual_model import PerceptualType, FileStorageService +from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository +from app.schemas import FileType from app.schemas.memory_perceptual_schema import ( PerceptualQuerySchema, PerceptualTimelineResponse, PerceptualMemoryItem, AudioModal, Content, VideoModal, TextModal ) +from app.schemas.model_schema import ModelInfo business_logger = get_business_logger() @@ -99,7 +107,7 @@ class MemoryPerceptualService: "keywords": 
content.keywords, "topic": content.topic, "domain": content.domain, - "created_time": int(memory.created_time.timestamp()*1000), + "created_time": int(memory.created_time.timestamp() * 1000), **detail } @@ -108,7 +116,8 @@ class MemoryPerceptualService: return result except Exception as e: - business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}") + business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", + exc_info=True) raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}", BizCode.DB_ERROR) @@ -138,7 +147,7 @@ class MemoryPerceptualService: for memory in memories: meta_data = memory.meta_data or {} content = meta_data.get("content", {}) - + # 安全地提取 content 字段,提供默认值 if content: content_obj = Content(**content) @@ -149,7 +158,7 @@ class MemoryPerceptualService: topic = "Unknown" domain = "Unknown" keywords = [] - + memory_item = PerceptualMemoryItem( id=memory.id, perceptual_type=PerceptualType(memory.perceptual_type), @@ -161,7 +170,7 @@ class MemoryPerceptualService: topic=topic, domain=domain, keywords=keywords, - created_time=int(memory.created_time.timestamp()*1000), + created_time=int(memory.created_time.timestamp() * 1000), storage_service=FileStorageService(memory.storage_service), ) memory_items.append(memory_item) @@ -183,3 +192,98 @@ class MemoryPerceptualService: except Exception as e: business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}") raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR) + + async def generate_perceptual_memory( + self, + end_user_id: str, + model_config: ModelInfo, + file_type: str, + file_url: str, + file_message: dict, + ): + memories = self.repository.get_by_url(file_url) + if memories: + business_logger.info(f"Perceptual memory already exists: {file_url}") + if end_user_id not in [memory.end_user_id for memory in memories]: + 
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") + memory_cache = memories[0] + self.repository.create_perceptual_memory( + end_user_id=uuid.UUID(end_user_id), + perceptual_type=PerceptualType(memory_cache.perceptual_type), + file_path=memory_cache.file_path, + file_name=memory_cache.file_name, + file_ext=memory_cache.file_ext, + summary=memory_cache.summary, + meta_data=memory_cache.meta_data + ) + self.db.commit() + + return + llm = RedBearLLM(RedBearModelConfig( + model_name=model_config.model_name, + provider=model_config.provider, + api_key=model_config.api_key, + base_url=model_config.api_base, + is_omni=model_config.is_omni + ), type=model_config.model_type) + try: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: + opt_system_prompt = f.read() + rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh') + except FileNotFoundError: + raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + messages = [ + {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]}, + {"role": RoleType.USER.value, "content": [ + {"type": "text", "text": "Summarize the following file"}, file_message + ]} + ] + result = await llm.ainvoke(messages) + content = json_repair.repair_json(result.content, return_objects=True) + path = urlparse(file_url).path + filename = os.path.basename(path) + filename = unquote(filename) + file_ext = os.path.splitext(filename)[1] + if not file_ext: + if file_type == FileType.AUDIO: + file_ext = ".mp3" + elif file_type == FileType.VIDEO: + file_ext = ".mp4" + elif file_type == FileType.DOCUMENT: + file_ext = ".txt" + elif file_type == FileType.IMAGE: + file_ext = ".jpg" + filename += file_ext + file_content = { + "keywords": content.get("keywords", []), + "topic": 
content.get("topic"), + "domain": content.get("domain") + } + if file_type in [FileType.IMAGE, FileType.VIDEO]: + file_modalities = { + "scene": content.get("scene") + } + elif file_type in [FileType.DOCUMENT]: + file_modalities = { + "section_count": content.get("section_count"), + "title": content.get("title"), + "first_line": content.get("first_line") + } + else: + file_modalities = { + "speaker_count": content.get("speaker_count") + } + self.repository.create_perceptual_memory( + end_user_id=uuid.UUID(end_user_id), + perceptual_type=PerceptualType.trans_from_file_type(file_type), + file_path=file_url, + file_name=filename, + file_ext=file_ext, + summary=content.get('summary'), + meta_data={ + "content": file_content, + "modalities": file_modalities + } + ) + self.db.commit() diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index fffca2e5..b30b48b2 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -10,6 +10,7 @@ """ import base64 import io +import uuid from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional @@ -23,9 +24,12 @@ from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger +from app.models import ModelApiKey from app.models.file_metadata_model import FileMetadata from app.schemas.app_schema import FileInput, FileType, TransferMethod +from app.schemas.model_schema import ModelInfo from app.services.audio_transcription_service import AudioTranscriptionService +from app.tasks import write_perceptual_memory logger = get_business_logger() @@ -39,6 +43,7 @@ DOC_MIME = [ class MultimodalFormatStrategy(ABC): """多模态格式策略基类""" + def __init__(self, file: FileInput): self.file = file @@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): if transcription: return { "type": "text", - "text": f"" + "text": 
f"" } # 通义千问音频格式:{"type": "audio", "audio": "url"} return { @@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): # 下载图片 if content is None: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) + response = await client.get(url, follow_redirects=True) response.raise_for_status() content = response.content self.file.set_content(content) @@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): audio_data = content if content is None: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) + response = await client.get(url, follow_redirects=True) response.raise_for_status() audio_data = response.content self.file.set_content(audio_data) @@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = { class MultimodalService: - """多模态文件处理服务""" + """ + Service for handling multimodal file processing. - def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, - enable_audio_transcription: bool = False, is_omni: bool = False): + Attributes: + db (Session): Database session. + model_api_key (str): API key for the model provider. + provider (str): Name of the model provider. + is_omni (bool): Indicates whether the model supports full multimodal capability. + capability (list): Capability configuration of the model. + audio_api_key (str | None): API key used for audio transcription. + enable_audio_transcription (bool): Whether audio transcription is enabled. + """ + + def __init__( + self, + db: Session, + api_config: ModelInfo | None = None, + audio_api_key: Optional[str] = None, + enable_audio_transcription: bool = False, + ): """ - 初始化多模态服务 - + Initialize the multimodal service. + Args: - db: 数据库会话 - provider: 模型提供商(dashscope, bedrock, anthropic, openai 等) - api_key: API 密钥(用于音频转文本) - enable_audio_transcription: 是否启用音频转文本 - is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式) + db (Session): Database session. 
+ api_config (ModelApiKey | None): Model API configuration. + audio_api_key (str | None): API key for audio transcription. + enable_audio_transcription (bool): Enable audio transcription. """ self.db = db - self.provider = provider.lower() - self.api_key = api_key + self.api_config = api_config + if self.api_config is not None: + self.model_api_key = api_config.api_key + self.provider = api_config.provider.lower() + self.is_omni = api_config.is_omni + self.capability = api_config.capability + self.audio_api_key = audio_api_key self.enable_audio_transcription = enable_audio_transcription - self.is_omni = is_omni async def process_files( self, - files: Optional[List[FileInput]] + end_user_id: uuid.UUID | str, + files: Optional[List[FileInput]], + ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 Args: + end_user_id: 用户ID files: 文件输入列表 Returns: @@ -319,6 +346,8 @@ class MultimodalService: """ if not files: return [] + if isinstance(end_user_id, uuid.UUID): + end_user_id = str(end_user_id) # 获取对应的策略 # dashscope 的 omni 模型使用 OpenAI 兼容格式 @@ -333,19 +362,25 @@ class MultimodalService: result = [] for idx, file in enumerate(files): strategy = strategy_class(file) + if not file.url: + file.url = await self.get_file_url(file) try: - if file.type == FileType.IMAGE: + if file.type == FileType.IMAGE and "vision" in self.capability: content = await self._process_image(file, strategy) result.append(content) + self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: content = await self._process_document(file, strategy) result.append(content) - elif file.type == FileType.AUDIO: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) + elif file.type == FileType.AUDIO and "audio" in self.capability: content = await self._process_audio(file, strategy) result.append(content) - elif file.type == FileType.VIDEO: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) + elif file.type == 
FileType.VIDEO and "video" in self.capability: content = await self._process_video(file, strategy) result.append(content) + self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -355,7 +390,8 @@ class MultimodalService: "file_index": idx, "file_type": file.type, "error": str(e) - } + }, + exc_info=True ) # 继续处理其他文件,不中断整个流程 result.append({ @@ -366,6 +402,17 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result + def write_perceptual_memory( + self, + end_user_id: str, + file_type: str, + file_url: str, + file_message: dict + ): + """写入感知记忆""" + if end_user_id and self.api_config: + write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) + async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理图片文件 @@ -387,43 +434,6 @@ class MultimodalService: "text": f"[图片处理失败: {str(e)}]" } - @staticmethod - async def _download_and_encode_image(url: str) -> tuple[str, str]: - """ - 下载图片并转换为 base64 - - Args: - url: 图片 URL - - Returns: - tuple: (base64_data, media_type) - """ - from mimetypes import guess_type - - # 下载图片 - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(url) - response.raise_for_status() - - # 获取图片数据 - image_data = response.content - - # 确定 media type - content_type = response.headers.get("content-type") - if content_type and content_type.startswith("image/"): - media_type = content_type - else: - # 从 URL 推断 - guessed_type, _ = guess_type(url) - media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - - # 转换为 base64 - base64_data = base64.b64encode(image_data).decode("utf-8") - - logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - - return base64_data, media_type - async def _process_document(self, file: FileInput, strategy) -> 
Dict[str, Any]: """ 处理文档文件(PDF、Word 等) @@ -436,7 +446,6 @@ class MultimodalService: Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: - # 远程文档暂不支持提取 return { "type": "text", "text": f"\n{await self._extract_document_text(file)}\n" @@ -471,12 +480,12 @@ class MultimodalService: # 如果启用音频转文本且有 API Key transcription = None - if self.enable_audio_transcription and self.api_key: + if self.enable_audio_transcription and self.audio_api_key: logger.info(f"开始音频转文本: {url}") if self.provider == "dashscope": - transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) + transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key) elif self.provider == "openai": - transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) + transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key) else: logger.warning(f"Provider {self.provider} 不支持音频转文本") @@ -557,7 +566,7 @@ class MultimodalService: file_content = file.get_content() if not file_content: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(file.url) + response = await client.get(file.url, follow_redirects=True) response.raise_for_status() file_content = response.content file.set_content(file_content) diff --git a/api/app/services/prompt/perceptual_summary_system.jinja2 b/api/app/services/prompt/perceptual_summary_system.jinja2 new file mode 100644 index 00000000..ee5d3eb5 --- /dev/null +++ b/api/app/services/prompt/perceptual_summary_system.jinja2 @@ -0,0 +1,53 @@ +{% raw %}You are a professional information extraction system. + +Your task is to analyze the provided document content and generate structured metadata. + +Extract the following fields: + +* **summary**: A concise summary of the document in 2–4 sentences. +* **keywords**: 5–10 important keywords or key phrases that best represent the document. 
This field MUST be a JSON array of strings. +* **topic**: The primary topic of the document expressed as a short phrase (3–8 words). +* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.). + +STRICT RULES: + +1. Output MUST be valid JSON. +2. Do NOT output markdown. +3. Do NOT output explanations. +4. Do NOT output any text before or after the JSON. +5. The JSON MUST contain EXACTLY these four keys: + * summary + * keywords + * topic + * domain{% endraw %} +{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %} +{% if file_type == 'audio' %} * speaker_count {% endif %} +{% if file_type == 'document' %} * section_count + * title + * first_line +{% endif %} +{% raw %} +6. `keywords` MUST be a JSON array of strings. +7. If the document content is insufficient, infer the best possible answer based on context. +8. Ensure the JSON is syntactically correct. +{% endraw %} +9. 
Output using the language {{ language }} +{% raw %} +Required JSON format: + +{ +"summary": "string", +"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"], +"topic": "string", +"domain": "string", +{% endraw %} +{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %} +{% if file_type == 'document' %} "section_count": integer +"title": "string", +"first_line": "string" +{% endif %} +{% if file_type == 'audio' %} "speaker_count": integer {% endif %} +{% raw %} +} + +Now analyze the following document and return the JSON result.{% endraw %} diff --git a/api/app/services/tenant_service.py b/api/app/services/tenant_service.py index 2edb46df..066edf57 100644 --- a/api/app/services/tenant_service.py +++ b/api/app/services/tenant_service.py @@ -217,4 +217,55 @@ class TenantService: skip=skip, limit=limit, is_active=is_active - ) \ No newline at end of file + ) + + def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]: + """获取租户语言配置""" + tenant = self.tenant_repo.get_tenant_by_id(tenant_id) + if not tenant: + raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND) + + return { + "default_language": tenant.default_language, + "supported_languages": tenant.supported_languages + } + + def update_tenant_language_config( + self, + tenant_id: uuid.UUID, + default_language: str, + supported_languages: list + ) -> Optional[dict]: + """更新租户语言配置""" + # 检查租户是否存在 + tenant = self.tenant_repo.get_tenant_by_id(tenant_id) + if not tenant: + raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND) + + # 验证默认语言在支持的语言列表中 + if default_language not in supported_languages: + raise BusinessException( + "默认语言必须在支持的语言列表中", + code=BizCode.VALIDATION_FAILED + ) + + try: + # 更新语言配置 + tenant.default_language = default_language + tenant.supported_languages = supported_languages + self.db.commit() + self.db.refresh(tenant) + + business_logger.info( + f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), " + 
f"默认语言: {default_language}, 支持语言: {supported_languages}" + ) + + return { + "default_language": tenant.default_language, + "supported_languages": tenant.supported_languages + } + except Exception as e: + self.db.rollback() + business_logger.error(f"更新租户语言配置失败: {str(e)}") + raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR) diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index 22dabed7..e23b1ac3 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User: async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User: """普通用户修改自己的密码""" + from app.i18n.service import t + business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}") # 检查权限:只能修改自己的密码 if current_user.id != user_id: business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}") - raise PermissionDeniedException("You can only change your own password") + raise PermissionDeniedException(t("auth.password.change_failed")) try: # 获取用户 db_user = user_repository.get_user_by_id(db=db, user_id=user_id) if not db_user: business_logger.warning(f"用户不存在: {user_id}") - raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND) # 验证旧密码 if not verify_password(old_password, db_user.hashed_password): business_logger.warning(f"用户旧密码验证失败: {user_id}") - raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED) + raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED) # 更新密码 db_user.hashed_password = get_password_hash(new_password) @@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne except Exception as e: business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}") 
db.rollback() - raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR) + raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR) async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]: @@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass Returns: tuple[User, str]: (更新后的用户对象, 实际使用的密码) """ + from app.i18n.service import t + business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}") # 检查权限:只有超级管理员可以修改他人密码 @@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass try: permission_service.check_superuser( subject, - error_message="只有超级管理员可以修改他人密码" + error_message=t("auth.password.change_failed") ) except PermissionDeniedException as e: business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}") @@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id) if not target_user: business_logger.warning(f"目标用户不存在: {target_user_id}") - raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND) + raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND) # 检查租户权限:超管只能修改同租户用户的密码 if current_user.tenant_id != target_user.tenant_id: business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}") - raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN) + raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN) # 如果没有提供新密码,则生成随机密码 actual_password = new_password if new_password else generate_random_password() @@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass except Exception as e: business_logger.error(f"管理员修改用户密码失败: 
admin={current_user.id}, target_user={target_user_id} - {str(e)}") db.rollback() - raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR) + raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR) def generate_random_password(length: int = 12) -> str: @@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em # # business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}") # return db_user + + +def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str: + """获取用户语言偏好""" + business_logger.info(f"获取用户语言偏好: user_id={user_id}") + + # 权限检查:只能获取自己的语言偏好 + if current_user.id != user_id: + raise PermissionDeniedException("只能获取自己的语言偏好") + + db_user = user_repository.get_user_by_id(db=db, user_id=user_id) + if not db_user: + raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) + + language = db_user.preferred_language or "zh" + business_logger.info(f"用户语言偏好: {db_user.username}, language={language}") + return language + + +def update_user_language_preference( + db: Session, + user_id: uuid.UUID, + language: str, + current_user: User +) -> User: + """更新用户语言偏好""" + business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}") + + # 权限检查:只能修改自己的语言偏好 + if current_user.id != user_id: + raise PermissionDeniedException("只能修改自己的语言偏好") + + # 验证语言代码是否支持 + from app.core.config import settings + if language not in settings.I18N_SUPPORTED_LANGUAGES: + raise BusinessException( + f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}", + code=BizCode.VALIDATION_FAILED + ) + + db_user = user_repository.get_user_by_id(db=db, user_id=user_id) + if not db_user: + raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) + + # 更新语言偏好 + db_user.preferred_language = language + db.commit() + db.refresh(db_user) + + business_logger.info(f"用户语言偏好更新成功: 
{db_user.username}, language={language}") + return db_user diff --git a/api/app/tasks.py b/api/app/tasks.py index 6fd9c954..5e1550bd 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,6 +1,5 @@ import asyncio -import json -import logging +import hashlib import os import re import shutil @@ -11,20 +10,48 @@ from datetime import datetime, timezone from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional -from uuid import UUID import redis -import requests from redis.exceptions import RedisError -logger = logging.getLogger(__name__) +# Import a unified Celery instance +from app.celery_app import celery_app +from app.core.config import settings +from app.core.logging_config import get_logger +from app.core.rag.crawler.web_crawler import WebCrawler +from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb +from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache +from app.core.rag.integrations.feishu.client import FeishuAPIClient +from app.core.rag.integrations.feishu.models import FileInfo +from app.core.rag.integrations.yuque.client import YuqueAPIClient +from app.core.rag.integrations.yuque.models import YuqueDocInfo +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.cv_model import QWenCV +from app.core.rag.llm.embedding_model import OpenAIEmbed +from app.core.rag.llm.sequence2txt_model import QWenSeq2txt +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.prompts.generator import question_proposal +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( + ElasticSearchVectorFactory, +) +from app.db import get_db, get_db_context +from app.models import Document, File, Knowledge +from app.schemas import document_schema, file_schema +from app.schemas.model_schema import ModelInfo +from app.services.memory_agent_service import MemoryAgentService +from app.services.memory_perceptual_service import MemoryPerceptualService +from 
app.utils.config_utils import resolve_config_id +from app.utils.redis_lock import RedisLock + +logger = get_logger(__name__) # 模块级同步 Redis 连接池,供 Celery 任务共享使用 # 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 # 使用连接池而非单例客户端,提供更好的并发性能和自动重连 -_sync_redis_pool: redis.ConnectionPool = None +_sync_redis_pool: redis.ConnectionPool | None = None -def _get_or_create_redis_pool() -> redis.ConnectionPool: + +def _get_or_create_redis_pool() -> redis.ConnectionPool | None: """获取或创建 Redis 连接池(懒初始化)""" global _sync_redis_pool if _sync_redis_pool is None: @@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool: return None return _sync_redis_pool + def get_sync_redis_client() -> Optional[redis.StrictRedis]: """获取同步 Redis 客户端(使用连接池) @@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: pool = _get_or_create_redis_pool() if pool is None: return None - + client = redis.StrictRedis(connection_pool=pool) # 验证连接可用性 client.ping() @@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]: logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True) return None -# Import a unified Celery instance -from app.celery_app import celery_app -from app.core.config import settings -from app.core.rag.crawler.web_crawler import WebCrawler -from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb -from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache -from app.core.rag.integrations.feishu.client import FeishuAPIClient -from app.core.rag.integrations.feishu.models import FileInfo -from app.core.rag.integrations.yuque.client import YuqueAPIClient -from app.core.rag.integrations.yuque.models import YuqueDocInfo -from app.core.rag.llm.chat_model import Base -from app.core.rag.llm.cv_model import QWenCV -from app.core.rag.llm.embedding_model import OpenAIEmbed -from app.core.rag.llm.sequence2txt_model import QWenSeq2txt -from app.core.rag.models.chunk import DocumentChunk -from 
app.core.rag.prompts.generator import question_proposal -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( - ElasticSearchVectorFactory, -) -from app.db import get_db, get_db_context -from app.models.document_model import Document -from app.models.file_model import File -from app.models.knowledge_model import Knowledge -from app.schemas import document_schema, file_schema -from app.services.memory_agent_service import MemoryAgentService -from app.utils.config_utils import resolve_config_id + +def set_asyncio_event_loop(): + """Set the asyncio event loop for the current thread.""" + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop @celery_app.task(name="tasks.process_item") @@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID): vector_size = len(vts[0]) init_graphrag(task, vector_size) - async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service, - chat_model, embedding_model, callback, with_resolution: bool = True, - with_community: bool = True, ) -> dict: + async def _run( + row: dict, + document_ids: list[str], + language: str, + parser_config: dict, + vector_service, + chat_model, + embedding_model, + callback, + with_resolution: bool = True, + with_community: bool = True + ) -> dict: await trio.sleep(5) # Delay for 10 seconds nonlocal progress_msg # Declare the use of an external progress_msg variable result = await run_graphrag_for_kb( @@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): with_community=with_community, ) ) + try: with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(sync_task) @@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_community=with_community, ) ) + try: with ThreadPoolExecutor(max_workers=1) as executor: 
future = executor.submit(sync_task) @@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s # Log but continue - will fail later with proper error pass - async def _run() -> str: + async def _run() -> dict: with get_db_context() as db: service = MemoryAgentService() - return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, - storage_type, user_rag_memory_id) + return await service.read_memory( + end_user_id, + message, + history, + search_switch, + actual_config_id, db, + storage_type, user_rag_memory_id + ) try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str, +def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, + user_rag_memory_id: str, language: str = "zh") -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. 
Args: @@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s Raises: Exception on failure """ - from app.core.logging_config import get_logger - logger = get_logger(__name__) - logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}") + logger.info( + f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, " + f"config_id={config_id} (type: {type(config_id).__name__}), " + f"storage_type={storage_type}, language={language}") start_time = time.time() # Convert config_id to UUID @@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s try: with get_db_context() as db: actual_config_id = resolve_config_id(config_id, db) - print(100*'-') + print(100 * '-') print(actual_config_id) - print(100*'-') + print(100 * '-') logger.info( f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") + logger.error( + f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", "error": f"Invalid config_id format: {config_id} - {str(e)}", @@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s async def _run() -> str: with get_db_context() as db: logger.info( - f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") + f"[CELERY WRITE] Executing MemoryAgentService.write_memory " + f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() result = await 
service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) @@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s return result try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s } -def reflection_engine() -> None: - """Empty function placeholder for timed background reflection. - - Intentionally left blank; replace with real reflection logic later. - """ - import asyncio - - from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) - - -@celery_app.task(name="app.core.memory.agent.reflection.timer") -def reflection_timer_task() -> None: - """Periodic Celery task that invokes reflection_engine. - - Raises an exception on failure. 
- """ - reflection_engine() - - # unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: @@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "workspace_id": workspace_id, "elapsed_time": elapsed_time, } + + @celery_app.task( name="app.tasks.write_all_workspaces_memory_task", bind=True, @@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.models.app_model import App from app.models.end_user_model import EndUser from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all - api_logger = get_api_logger() - with get_db_context() as db: try: # 获取所有活跃的工作空间 @@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: ).all() if not workspaces: - api_logger.warning("没有找到活跃的工作空间") + logger.warning("没有找到活跃的工作空间") return { "status": "SUCCESS", "message": "没有找到活跃的工作空间", @@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "workspace_results": [] } - api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") + logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") all_workspace_results = [] # 遍历每个工作空间 for workspace in workspaces: workspace_id = workspace.id - api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") + logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") try: # 1. 查询当前workspace下的所有app(仅未删除的) @@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), }) - api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0") + logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0") continue # 2. 
查询所有app下的end_user_id(去重) @@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: }) except Exception as e: # 记录单个用户查询失败,但继续处理其他用户 - api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") + logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") end_user_details.append({ "end_user_id": str(end_user_id), "total": 0, @@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: "created_at": memory_increment.created_at.isoformat(), }) - api_logger.info( + logger.info( f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}" ) except Exception as e: db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间 - api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") + logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") all_workspace_results.append({ "workspace_id": str(workspace_id), "workspace_name": workspace.name, @@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: } except Exception as e: - api_logger.error(f"记忆增量统计任务执行失败: {str(e)}") + logger.error(f"记忆增量统计任务执行失败: {str(e)}") return { "status": "FAILURE", "error": str(e), @@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.repositories.end_user_repository import EndUserRepository from 
app.services.user_memory_service import UserMemoryService - logger = get_logger(__name__) logger.info("开始执行记忆缓存重新生成定时任务") service = UserMemoryService() @@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.models.workspace_model import Workspace from app.services.memory_reflection_service import ( MemoryReflectionService, WorkspaceAppService, ) - api_logger = get_api_logger() - with get_db_context() as db: try: # 获取所有工作空间 @@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: # 遍历每个工作空间 for workspace in workspaces: workspace_id = workspace.id - api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") try: reflection_service = MemoryReflectionService(db) @@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: workspace_reflection_results = [] for data in result['apps_detailed_info']: - if data['memory_configs'] == []: + if not data['memory_configs']: continue releases = data['releases'] @@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str( user['app_id']): # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + logger.info(f"为用户 {user['id']} 
启动反思,config_id: {config['config_id']}") reflection_result = await reflection_service.start_reflection_from_data( config_data=config, @@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]: "reflection_results": workspace_reflection_results }) - api_logger.info( + logger.info( f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: db.rollback() # Rollback failed transaction to allow next query - api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") + logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), "error": str(e), @@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } except Exception as e: - api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + logger.error(f"工作空间反思任务执行失败: {str(e)}") return { "status": "FAILURE", "error": str(e), @@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_api_logger from app.services.memory_forget_service import MemoryForgetService - api_logger = get_api_logger() - with get_db_context() as db: try: - api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") + logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") forget_service = MemoryForgetService() # 运行遗忘周期 + # FIXME: MemeoryForgetService 
report = await forget_service.trigger_forgetting( db=db, end_user_id=None, # 处理所有组 @@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di duration = time.time() - start_time - api_logger.info( + logger.info( f"遗忘周期定时任务完成: " f"融合 {report['merged_count']} 对节点, " f"失败 {report['failed_count']} 对, " @@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di except Exception as e: duration = time.time() - start_time - api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) + logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) return { "status": "FAILED", @@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di finally: loop.close() + # ============================================================================= # Long-term Memory Storage Tasks (Batched Write Strategies) # ============================================================================= @@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: - from sqlalchemy import func, select + from sqlalchemy import select - from app.core.logging_config import get_logger from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage from app.repositories.implicit_emotions_storage_repository import ( ImplicitEmotionsStorageRepository, @@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.implicit_memory_service import ImplicitMemoryService - logger = get_logger(__name__) logger.info("开始执行隐性记忆和情绪数据更新定时任务") total_users = 0 @@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: for end_user_id in refresh_iter: logger.info(f"开始处理用户: {end_user_id}") user_start_time = time.time() - + implicit_success = False emotion_success = False errors 
= [] @@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: failed += 1 user_elapsed = time.time() - user_start_time - + # 记录用户处理结果 user_result = { "end_user_id": end_user_id, @@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time @@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.repositories.implicit_emotions_storage_repository import ( ImplicitEmotionsStorageRepository, ) from app.services.emotion_analytics_service import EmotionAnalyticsService from app.services.implicit_memory_service import ImplicitMemoryService - logger = get_logger(__name__) logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}") initialized = 0 @@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time @@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ 默认生成中文(zh)兴趣分布数据。 
Args: + self: task object end_user_ids: 需要检查的用户ID列表 Returns: @@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ start_time = time.time() async def _run() -> Dict[str, Any]: - from app.core.logging_config import get_logger from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE from app.services.memory_agent_service import MemoryAgentService - logger = get_logger(__name__) logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}") initialized = 0 @@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ } try: - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time @@ -2720,3 +2611,54 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[ "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +@celery_app.task( + name="app.tasks.write_perceptual_memory", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) +def write_perceptual_memory( + self, + end_user_id: str, + model_api_config: dict, + file_type: str, + file_url: str, + file_message: dict +): + """ + Write perceptual memory for a user into PostgreSQL and Neo4j. + + This task generates or updates the user's perceptual memory + in the backend databases. It is intended to be executed asynchronously + via Celery. + + Args: + end_user_id (uuid.UUID): The unique identifier of the end user. + model_api_config (ModelInfo): API configuration for the model + used to generate perceptual memory. 
+ file_type (str): The file type + file_url (url): The url of file + file_message (dict): The file message containing details about the file + to be processed. + + Returns: + None + """ + file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() + set_asyncio_event_loop() + with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): + model_info = ModelInfo(**model_api_config) + with get_db_context() as db: + memory_perceptual_service = MemoryPerceptualService(db) + return asyncio.run(memory_perceptual_service.generate_perceptual_memory( + end_user_id, + model_info, + file_type, + file_url, + file_message, + )) diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py new file mode 100644 index 00000000..99f62d84 --- /dev/null +++ b/api/app/utils/redis_lock.py @@ -0,0 +1,61 @@ +import redis +import uuid +import time + +UNLOCK_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) +else + return 0 +end +""" + + +class RedisLock: + def __init__( + self, + key: str, + redis_client: redis.StrictRedis, + expire: int = 60, + retry_interval: float = 0.1, + timeout: float = 30 + + ): + self.key = key + self.expire = expire + self.value = str(uuid.uuid4()) + self._locked = False + self.retry_interval = retry_interval + self.timeout = timeout + self.redis_client = redis_client + + def acquire(self) -> bool: + start = time.time() + while True: + ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) + if ok: + self._locked = True + return True + if time.time() - start >= self.timeout: + return False + time.sleep(self.retry_interval) + + def release(self): + if not self._locked: + return + self.redis_client.eval( + UNLOCK_SCRIPT, + 1, + self.key, + self.value + ) + self._locked = False + + def __enter__(self): + ok = self.acquire() + if not ok: + raise RuntimeError(f"Get redis lock timeout: {self.key}") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + 
self.release() diff --git a/api/migrations/versions/01587a13522f_202603131028.py b/api/migrations/versions/01587a13522f_202603131028.py new file mode 100644 index 00000000..6412dedd --- /dev/null +++ b/api/migrations/versions/01587a13522f_202603131028.py @@ -0,0 +1,38 @@ +"""202603131028 + +Revision ID: 01587a13522f +Revises: fb834419b18f +Create Date: 2026-03-13 10:28:43.601370 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '01587a13522f' +down_revision: Union[str, None] = 'fb834419b18f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tenants', sa.Column('default_language', sa.String(length=10), server_default='zh', nullable=False)) + op.add_column('tenants', sa.Column('supported_languages', postgresql.ARRAY(sa.String(length=10)), server_default=sa.text("'{zh,en}'"), nullable=False)) + op.create_index(op.f('ix_tenants_default_language'), 'tenants', ['default_language'], unique=False) + op.add_column('users', sa.Column('preferred_language', sa.String(length=10), server_default=sa.text("'zh'"), nullable=False)) + op.create_index(op.f('ix_users_preferred_language'), 'users', ['preferred_language'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f('ix_users_preferred_language'), table_name='users') + op.drop_column('users', 'preferred_language') + op.drop_index(op.f('ix_tenants_default_language'), table_name='tenants') + op.drop_column('tenants', 'supported_languages') + op.drop_column('tenants', 'default_language') + # ### end Alembic commands ### diff --git a/api/migrations/versions/ea31b4e347d8_202603131452.py b/api/migrations/versions/ea31b4e347d8_202603131452.py new file mode 100644 index 00000000..12716fd9 --- /dev/null +++ b/api/migrations/versions/ea31b4e347d8_202603131452.py @@ -0,0 +1,30 @@ +"""202603131452 + +Revision ID: ea31b4e347d8 +Revises: 01587a13522f +Create Date: 2026-03-13 14:53:20.587580 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ea31b4e347d8' +down_revision: Union[str, None] = '01587a13522f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('app_shares', sa.Column('permission', sa.String(), nullable=False, comment='权限模式: readonly | editable')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('app_shares', 'permission') + # ### end Alembic commands ### diff --git a/web/package.json b/web/package.json index e2d5c898..2799a631 100644 --- a/web/package.json +++ b/web/package.json @@ -36,6 +36,7 @@ "codemirror": "^6.0.2", "copy-to-clipboard": "^3.3.3", "crypto-js": "^4.2.0", + "d3": "^7.9.0", "dayjs": "^1.11.18", "echarts": "^5.6.0", "echarts-for-react": "^3.0.2", @@ -67,6 +68,7 @@ "@tailwindcss/vite": "^4.1.14", "@types/codemirror": "^5.60.17", "@types/crypto-js": "^4.2.2", + "@types/d3": "^7.4.3", "@types/js-yaml": "^4.0.9", "@types/node": "^24.6.0", "@types/react": "^18.2.0", diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 491f78ea..b8bfac32 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,9 +2,10 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-04 10:58:41 + * @Last Modified time: 2026-03-13 10:48:41 */ import { request } from '@/utils/request' +import type { AxiosRequestConfig } from 'axios' import type { MemoryFormData, } from '@/views/MemoryManagement/types' @@ -94,8 +95,12 @@ export const updatedEndUserProfile = (values: EndUser) => { return request.post(`/memory-storage/updated_end_user/profile`, values) } // User Memory - Relationship network -export const getMemorySearchEdges = (end_user_id: string) => { - return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }) +export const getMemorySearchEdges = (end_user_id: string, config?: AxiosRequestConfig) => { + return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }, config) +} +// User Memory - Community graph +export const getMemoryCommunityGraph = (end_user_id: string, config?: AxiosRequestConfig) => { + return request.get(`/memory-storage/analytics/community_graph`, { end_user_id }, config) } // User Memory - User interest distribution export const getInterestDistributionByUser = (end_user_id: string) => { diff --git 
a/web/src/components/D3Graph/CommunityGraph.tsx b/web/src/components/D3Graph/CommunityGraph.tsx new file mode 100644 index 00000000..549d69f3 --- /dev/null +++ b/web/src/components/D3Graph/CommunityGraph.tsx @@ -0,0 +1,67 @@ +import React, { useState, useRef, useMemo, useEffect, type FC } from 'react' +import Empty from '@/components/Empty' +import { GRAPH_COLORS, initCommunityGraph } from './utils' +import { useD3Graph } from './hooks' +import type { CommunityD3Node, D3Link, CommunityGraphProps } from './types' + +// ─── Component ──────────────────────────────────────────────────────────────── +// Renders a D3-powered community graph with optional tooltip and legend. + +const CommunityGraph: FC = ({ + data, + empty: emptyProp, + colors = GRAPH_COLORS, + renderTooltip, + showLegend = true, + onCommunityClick, + onNodeClick, + defaultZoom = 1, +}) => { + // Tooltip position and hovered node state + const [tooltip, setTooltip] = useState<{ x: number; y: number; node: CommunityD3Node } | null>(null) + + // Keep callback refs stable to avoid re-initializing the graph on every render + const onCommunityClickRef = useRef(onCommunityClick) + const onNodeClickRef = useRef(onNodeClick) + const renderTooltipRef = useRef(renderTooltip) + useEffect(() => { onCommunityClickRef.current = onCommunityClick }, [onCommunityClick]) + useEffect(() => { onNodeClickRef.current = onNodeClick }, [onNodeClick]) + useEffect(() => { renderTooltipRef.current = renderTooltip }, [renderTooltip]) + + const graphState = useMemo(() => data, [data]) + // Show empty state when explicitly flagged or when there are no nodes + const isEmpty = emptyProp ?? 
!data?.nodes.length + + // Initialize (or re-initialize) the D3 graph whenever relevant state changes + const containerRef = useD3Graph((container) => { + if (!graphState) return + return initCommunityGraph( + container, + graphState.nodes, + graphState.links as D3Link[], + graphState.communityMap, + graphState.communityCaption, + graphState.communityNodeMap, + { colors, showLegend, defaultZoom, setTooltip: renderTooltip ? setTooltip : () => {}, onCommunityClickRef, onNodeClickRef } + ) + }, [graphState, showLegend, defaultZoom]) + + // Resolve tooltip content: use custom renderer if provided, otherwise fall back to DefaultTooltip + const tooltipNode = tooltip && renderTooltipRef.current + ? renderTooltipRef.current(tooltip.node) + : null + + if (isEmpty) return + return ( +
+
+ {tooltipNode ? ( +
+ {tooltipNode} +
+ ) : undefined} +
+ ) +} + +export default React.memo(CommunityGraph) diff --git a/web/src/components/D3Graph/hooks.ts b/web/src/components/D3Graph/hooks.ts new file mode 100644 index 00000000..93355718 --- /dev/null +++ b/web/src/components/D3Graph/hooks.ts @@ -0,0 +1,24 @@ +import { useRef, useEffect } from 'react' +import * as d3 from 'd3' + +/** + * Generic hook that mounts a D3 graph inside a div container. + * Clears any existing SVG before calling initFn, and runs cleanup on unmount or dep change. + */ +export function useD3Graph( + initFn: (container: HTMLDivElement) => (() => void) | void, + deps: T[] +) { + const containerRef = useRef(null) + useEffect(() => { + const container = containerRef.current + if (!container) return + d3.select(container).selectAll('svg').remove() + const cleanup = initFn(container) + return () => { + cleanup?.() + d3.select(container).selectAll('svg').remove() + } + }, deps) + return containerRef +} diff --git a/web/src/components/D3Graph/types.ts b/web/src/components/D3Graph/types.ts new file mode 100644 index 00000000..88bbbfeb --- /dev/null +++ b/web/src/components/D3Graph/types.ts @@ -0,0 +1,102 @@ +import type { ReactNode, RefObject } from 'react' +import type * as d3 from 'd3' + +// ─── Raw input types (mirror of API response, no external dependency) ───────── +// These interfaces map 1-to-1 with the graph API response shape. 
+ +export interface RawCommunityNode { + id: string + label: 'Community' + properties: { + name: string + summary: string + member_entity_ids: string[] + member_count: number + core_entities: string[] + community_id: string + end_user_id?: string + updated_at?: string + } +} + +export interface RawEntityNode { + id: string + label: 'ExtractedEntity' + properties: { + name: string + description: string + entity_type: string + community_name?: string + [key: string]: unknown + } +} + +export interface RawEdge { + id: string + source: string + target: string +} + +export interface RawCommunityGraphData { + nodes: (RawCommunityNode | RawEntityNode)[] + edges: RawEdge[] +} + +// ─── D3 graph types ─────────────────────────────────────────────────────────── +// Runtime node shape used by D3 simulations; extends SimulationNodeDatum for x/y/vx/vy. + +export interface CommunityD3Node extends d3.SimulationNodeDatum { + id: string + name: string + community: string + label: string + symbolSize: number + color: string + properties?: RawEntityNode['properties'] +} + +export interface D3Link extends d3.SimulationLinkDatum { + isCross: boolean +} + +// Convex-hull shape rendered behind each community cluster. +export interface HullDatum { + id: string + path: string + color: string + labelX: number + labelY: number + dashed: boolean + caption: string +} + +// Fully transformed graph data ready to be passed into initCommunityGraph. +export interface CommunityGraphData { + nodes: CommunityD3Node[] + links: Array<{ source: string; target: string; isCross: boolean }> + communityMap: Map + communityCaption: Map + communityNodeMap: Map +} + +// Props accepted by the CommunityGraph React component. 
+export interface CommunityGraphProps { + data: CommunityGraphData | null + empty?: boolean + colors?: string[] + renderTooltip?: (node: CommunityD3Node) => ReactNode + showLegend?: boolean + onCommunityClick?: (node: RawCommunityNode) => void + onNodeClick?: (node: CommunityD3Node) => void + defaultZoom?: number +} + +// Options forwarded from the React component into the D3 initializer. +export interface InitOptions { + colors: string[] + showLegend: boolean + defaultZoom: number + setTooltip: (s: { x: number; y: number; node: CommunityD3Node } | null) => void + onCommunityClickRef: RefObject<((node: RawCommunityNode) => void) | undefined> + onNodeClickRef: RefObject<((node: CommunityD3Node) => void) | undefined> +} diff --git a/web/src/components/D3Graph/utils.ts b/web/src/components/D3Graph/utils.ts new file mode 100644 index 00000000..87d888ac --- /dev/null +++ b/web/src/components/D3Graph/utils.ts @@ -0,0 +1,547 @@ +import * as d3 from 'd3' +import type { CommunityD3Node, D3Link, HullDatum, CommunityGraphData, RawCommunityGraphData, RawCommunityNode, RawEntityNode, InitOptions } from './types' + +// ─── Colors ─────────────────────────────────────────────────────────────────── + +export const GRAPH_COLORS = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048'] +export const colorAt = (i: number) => GRAPH_COLORS[i % GRAPH_COLORS.length] + +export function connectionToRadius(connections: number): number { + if (connections <= 1) return 5 + if (connections <= 10) return 8 + if (connections <= 15) return 11 + if (connections <= 20) return 16 + return 22 +} + +// ─── Arrow markers ──────────────────────────────────────────────────────────── + +export function addArrowMarkers( + defs: d3.Selection, + markers: { id: string; color: string }[] +) { + markers.forEach(({ id, color }) => { + defs.append('marker') + .attr('id', id) + .attr('viewBox', '0 -4 8 8') + .attr('refX', 8).attr('refY', 0) + .attr('markerWidth', 
6).attr('markerHeight', 6) + .attr('orient', 'auto') + .append('path').attr('d', 'M0,-4L8,0L0,4').attr('fill', color) + }) +} + +// ─── Zoom ───────────────────────────────────────────────────────────────────── + +export function addZoom( + svg: d3.Selection, + g: d3.Selection +) { + svg.call( + d3.zoom().scaleExtent([0.2, 4]) + .on('zoom', e => g.attr('transform', e.transform)) + ) +} + +// ─── Node drag ──────────────────────────────────────────────────────────────── + +export function makeNodeDrag( + simulation: d3.Simulation> +) { + return d3.drag() + .on('start', (e, d) => { if (!e.active) simulation.alphaTarget(0.3).restart(); d.fx = d.x; d.fy = d.y }) + .on('drag', (e, d) => { d.fx = e.x; d.fy = e.y }) + .on('end', (e, d) => { if (!e.active) simulation.alphaTarget(0); d.fx = e.x; d.fy = e.y }) +} + +// ─── Cluster force ──────────────────────────────────────────────────────────── +// Works for both string and number group keys. + +export function makeClusterForce( + nodes: N[], + getGroup: (d: N) => string | number, + centers: Record, + width: number, + height: number, + opts: { pullStrength?: number; minSepRatio?: number; pushStrength?: number } = {} +) { + const { pullStrength = 0.45, minSepRatio = 0.68, pushStrength = 1.0 } = opts + return (alpha: number) => { + // pre-group nodes by key to avoid repeated filter() in hot path + const groups = new Map() + nodes.forEach(d => { + const k = String(getGroup(d)) + if (!groups.has(k)) groups.set(k, []) + groups.get(k)!.push(d) + }) + // pull toward group center + nodes.forEach(d => { + const c = centers[getGroup(d)] + if (!c) return + d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * pullStrength * alpha + d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * pullStrength * alpha + }) + // live centroids + const centroids: Record = {} + nodes.forEach(d => { + const g = String(getGroup(d)) + if (!centroids[g]) centroids[g] = { x: 0, y: 0, n: 0 } + centroids[g].x += d.x ?? 0 + centroids[g].y += d.y ?? 
0 + centroids[g].n++ + }) + Object.values(centroids).forEach(c => { c.x /= c.n; c.y /= c.n }) + // push groups apart + const keys = Object.keys(centroids) + const minSep = Math.min(width, height) * minSepRatio + for (let i = 0; i < keys.length; i++) { + for (let j = i + 1; j < keys.length; j++) { + const ci = centroids[keys[i]], cj = centroids[keys[j]] + const dx = cj.x - ci.x, dy = cj.y - ci.y + const dist = Math.sqrt(dx * dx + dy * dy) || 1 + if (dist >= minSep) continue + const push = ((minSep - dist) / dist) * pushStrength * alpha + const fx = dx * push, fy = dy * push + groups.get(keys[i])?.forEach(d => { d.vx = (d.vx ?? 0) - fx; d.vy = (d.vy ?? 0) - fy }) + groups.get(keys[j])?.forEach(d => { d.vx = (d.vx ?? 0) + fx; d.vy = (d.vy ?? 0) + fy }) + } + } + } +} + +// ─── Group centers ──────────────────────────────────────────────────────────── + +export function buildGroupCenters( + keys: (string | number)[], + width: number, + height: number, + radiusRatio = 0.4 +): Record { + const centers: Record = {} + const r = Math.min(width, height) * radiusRatio + keys.forEach((key, i) => { + const angle = (i / keys.length) * 2 * Math.PI - Math.PI / 2 + centers[key] = { x: width / 2 + r * Math.cos(angle), y: height / 2 + r * Math.sin(angle) } + }) + return centers +} + +// ─── Community graph data transform ───────────────────────────────────────── + +export function buildCommunityGraphData(raw: RawCommunityGraphData, colors: string[] = GRAPH_COLORS): CommunityGraphData | null { + const getColor = (i: number) => colors[i % colors.length] + + const communityNodes = raw.nodes.filter(n => n.label === 'Community') as RawCommunityNode[] + const communityCaption = new Map() + const communityMap = new Map() + + communityNodes.forEach(n => { + communityCaption.set(n.id, n.properties.name) + communityMap.set(n.id, n.properties.member_entity_ids) + }) + + const entityToCommunity = new Map() + communityMap.forEach((members, commId) => members.forEach(eid => 
entityToCommunity.set(eid, commId))) + + const commKeys = Array.from(communityMap.keys()) + const commIndex = new Map(commKeys.map((k, i) => [k, i])) + + const entityNodes = raw.nodes.filter(n => n.label === 'ExtractedEntity') as RawEntityNode[] + const entityNodeSet = new Set(entityNodes.map(n => n.id)) + + const connectionCount: Record = {} + raw.edges.forEach(e => { + if (entityNodeSet.has(e.source)) connectionCount[e.source] = (connectionCount[e.source] || 0) + 1 + if (entityNodeSet.has(e.target)) connectionCount[e.target] = (connectionCount[e.target] || 0) + 1 + }) + + const nodes: CommunityD3Node[] = entityNodes.map(n => { + const commId = entityToCommunity.get(n.id) ?? commKeys[0] + return { + id: n.id, + name: n.properties.name, + community: commId, + label: n.label, + symbolSize: connectionToRadius(connectionCount[n.id] || 0), + color: getColor(commIndex.get(commId) ?? 0), + properties: n.properties, + } + }) + + if (!nodes.length) return null + + const links = raw.edges + .filter(e => entityNodeSet.has(e.source) && entityNodeSet.has(e.target)) + .map(e => ({ + source: e.source, + target: e.target, + isCross: entityToCommunity.get(e.source) !== entityToCommunity.get(e.target), + })) + + const communityNodeMap = new Map( + communityNodes.map(n => [n.id, n]) + ) + return { nodes, links, communityMap, communityCaption, communityNodeMap } +} + +// ─── Hull helpers ───────────────────────────────────────────────────────────── + +const smoothLine = d3.line<[number, number]>() + .x(d => d[0]).y(d => d[1]) + .curve(d3.curveCatmullRomClosed.alpha(0.5)) + +function expandPoints(pts: [number, number][], pad: number): [number, number][] { + const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length + const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length + return pts.map(([x, y]) => { + const dx = x - cx, dy = y - cy + const len = Math.sqrt(dx * dx + dy * dy) || 1 + return [x + (dx / len) * pad, y + (dy / len) * pad] + }) +} + +function toHullPoints(pts: [number, 
number][]): [number, number][] { + if (pts.length === 1) { + const [x, y] = pts[0] + return [[x - 1, y - 1], [x + 1, y - 1], [x, y + 1]] + } + if (pts.length === 2) { + const [[x1, y1], [x2, y2]] = pts + return [[x1, y1], [x2, y2], [(x1 + x2) / 2, (y1 + y2) / 2 - 1]] + } + return d3.polygonHull(pts) ?? pts +} + +const CIRCLE_THRESHOLD = 4 // 节点数 < 此值时使用圆形 +const CIRCLE_SEGMENTS = 32 + +function circlePoints(cx: number, cy: number, r: number): [number, number][] { + return Array.from({ length: CIRCLE_SEGMENTS }, (_, i) => { + const a = (i / CIRCLE_SEGMENTS) * 2 * Math.PI + return [cx + r * Math.cos(a), cy + r * Math.sin(a)] as [number, number] + }) +} + +export function buildHullData( + nodes: CommunityD3Node[], + communityMap: Map, + communityCaption: Map, + colors: string[] +): HullDatum[] { + const getColor = (i: number) => colors[i % colors.length] + const byComm = new Map() + communityMap.forEach((_, id) => byComm.set(id, [])) + nodes.forEach(d => { + if (d.x != null && d.y != null) byComm.get(d.community)?.push([d.x, d.y]) + }) + + const hulls: HullDatum[] = [] + let ci = 0 + byComm.forEach((pts, id) => { + const color = getColor(ci++) + if (!pts.length) return + let pathPoints: [number, number][] + if (pts.length < CIRCLE_THRESHOLD) { + const cx = pts.reduce((s, p) => s + p[0], 0) / pts.length + const cy = pts.reduce((s, p) => s + p[1], 0) / pts.length + pathPoints = circlePoints(cx, cy, 60) + } else { + pathPoints = expandPoints(toHullPoints(pts), 60) as [number, number][] + } + const path = smoothLine(pathPoints) + if (!path) return + hulls.push({ + id, path, color, + labelX: pathPoints.reduce((s, p) => s + p[0], 0) / pathPoints.length, + labelY: Math.min(...pathPoints.map(p => p[1])) - 10, + dashed: pts.length <= 2, + caption: communityCaption.get(id) ?? 
id, + }) + }) + return hulls +} + +// ─── Hull render ────────────────────────────────────────────────────────────── + +export function renderHulls( + hullG: d3.Selection, + hulls: HullDatum[], + hiddenCommunities: Set, + nodes: CommunityD3Node[], + simulation: d3.Simulation, + onCommunityClick?: (node: RawCommunityNode) => void, + communityNodeMap?: Map +) { + let dragNodes: CommunityD3Node[] = [] + let dragStart = { x: 0, y: 0 } + const communityDrag = d3.drag() + .on('start', (event, d) => { + if (!event.active) simulation.alphaTarget(0.3).restart() + dragNodes = nodes.filter(n => n.community === d.id) + dragStart = { x: event.x, y: event.y } + dragNodes.forEach(n => { n.fx = n.x; n.fy = n.y }) + }) + .on('drag', (event) => { + const dx = event.x - dragStart.x, dy = event.y - dragStart.y + dragStart = { x: event.x, y: event.y } + dragNodes.forEach(n => { n.fx = (n.fx ?? n.x ?? 0) + dx; n.fy = (n.fy ?? n.y ?? 0) + dy }) + }) + .on('end', (event) => { if (!event.active) simulation.alphaTarget(0) }) + + const pathSel = hullG.selectAll('path.hull').data(hulls, d => d.id) + pathSel.enter().append('path').attr('class', 'hull').style('cursor', 'grab') + .merge(pathSel) + .call(communityDrag) + .attr('d', d => d.path) + .attr('fill', d => d.color).attr('fill-opacity', 0.08) + .attr('stroke', d => d.color).attr('stroke-opacity', 0.5).attr('stroke-width', 1.5) + .attr('stroke-dasharray', 'none') + .style('display', d => hiddenCommunities.has(d.id) ? 
'none' : null) + .on('click', (event, d) => { + if ((event as MouseEvent).defaultPrevented) return + const node = communityNodeMap?.get(d.id) + if (node) onCommunityClick?.(node) + }) + pathSel.exit().remove() + + const labelSel = hullG.selectAll('text.hull-label').data(hulls, d => d.id) + labelSel.enter().append('text').attr('class', 'hull-label') + .attr('text-anchor', 'middle').attr('font-size', '12px').attr('font-weight', '500') + .style('pointer-events', 'none') + .merge(labelSel) + .attr('x', d => d.labelX).attr('y', d => d.labelY) + .attr('fill', d => d.color) + .style('display', d => hiddenCommunities.has(d.id) ? 'none' : null) + .text(d => d.caption) + labelSel.exit().remove() +} + +// ─── Community graph init ───────────────────────────────────────────────────── + +export function initCommunityGraph( + container: HTMLDivElement, + nodes: CommunityD3Node[], + links: D3Link[], + communityMap: Map, + communityCaption: Map, + communityNodeMap: Map, + opts: InitOptions +) { + const { colors, showLegend, defaultZoom, setTooltip, onCommunityClickRef, onNodeClickRef } = opts + const getColor = (i: number) => colors[i % colors.length] + + const width = container.clientWidth || 600 + const height = container.clientHeight || 518 + + const svg = d3.select(container).append('svg') + .attr('width', width).attr('height', height) + .style('width', '100%').style('height', '100%') + .style('background', '#F6F8FC') + + const g = svg.append('g') + + const zoom = d3.zoom() + .scaleExtent([0.2, 4]) + .on('zoom', e => g.attr('transform', e.transform)) + svg.call(zoom) + if (defaultZoom !== 1) { + svg.call(zoom.transform, d3.zoomIdentity + .translate(width / 2 * (1 - defaultZoom), height / 2 * (1 - defaultZoom)) + .scale(defaultZoom) + ) + } + + const defs = svg.append('defs') + addArrowMarkers(defs, [{ id: 'arrow', color: 'rgba(91, 97, 103, 0.7)' }]) + + const commKeys = Array.from(communityMap.keys()) + const centers = buildGroupCenters(commKeys, width, height, 0.45) + const 
linkedIds = new Set(links.flatMap(l => [l.source as string, l.target as string])) + + const simulation = d3.forceSimulation(nodes) + .force('link', d3.forceLink(links).id(d => d.id).distance(60)) + .force('charge', d3.forceManyBody().strength(-120)) + .force('center', d3.forceCenter(width / 2, height / 2).strength(0.02)) + .force('collision', d3.forceCollide(d => d.symbolSize + 16)) + .force('cluster', makeClusterForce(nodes, d => d.community, centers, width, height, { + pullStrength: 0.45, minSepRatio: 0.68, pushStrength: 1.0, + })) + .force('isolatedPull', (alpha: number) => { + nodes.forEach(d => { + if (linkedIds.has(d.id)) return + const c = centers[d.community] + if (!c) return + d.vx = (d.vx ?? 0) + (c.x - (d.x ?? 0)) * 0.4 * alpha + d.vy = (d.vy ?? 0) + (c.y - (d.y ?? 0)) * 0.4 * alpha + }) + }) + + const hullG = g.append('g').attr('class', 'hulls') + const hiddenCommunities = new Set() + + const linkSel = g.append('g').selectAll('line') + .data(links).enter().append('line') + .attr('stroke', '#5B6167') + .attr('stroke-opacity', d => d.isCross ? 0.3 : 0.5) + .attr('stroke-width', d => d.isCross ? 
1 : 1.2) + .attr('marker-end', 'url(#arrow)') + + const nodeSel = g.append('g').selectAll('g') + .data(nodes).enter().append('g') + .call(makeNodeDrag(simulation)) + + nodeSel.append('circle') + .attr('r', d => d.symbolSize) + .attr('fill', d => d.color).attr('fill-opacity', 0.85) + .attr('stroke', '#fff').attr('stroke-width', 1.5) + .style('cursor', 'pointer') + .on('mouseenter', (event: MouseEvent, d: CommunityD3Node) => { + const { left, top } = container.getBoundingClientRect() + setTooltip({ x: event.clientX - left, y: event.clientY - top, node: d }) + }) + .on('mousemove', (event: MouseEvent) => { + const { left, top } = container.getBoundingClientRect() + const nd = d3.select(event.target as SVGCircleElement).datum() + setTooltip({ x: event.clientX - left, y: event.clientY - top, node: nd }) + }) + .on('mouseleave', () => setTooltip(null)) + .on('click', (_event: MouseEvent, d: CommunityD3Node) => onNodeClickRef.current?.(d)) + + nodeSel.append('text') + .text(d => d.name) + .attr('x', 0).attr('dy', d => -(d.symbolSize + 5)) + .attr('text-anchor', 'middle').attr('font-size', '11px').attr('fill', '#444') + .style('pointer-events', 'none') + + if (showLegend) { + renderLegend( + svg, + commKeys.map((cid, i) => ({ key: cid, label: communityCaption.get(cid) ?? cid, color: getColor(i) })), + width, height, + (key, hidden) => { + const cid = key as string + if (hidden) hiddenCommunities.add(cid) + else hiddenCommunities.delete(cid) + nodeSel.style('display', d => hiddenCommunities.has(d.community) ? 'none' : null) + linkSel.style('display', d => { + const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node + return hiddenCommunities.has(s.community) || hiddenCommunities.has(t.community) ? 'none' : null + }) + hullG.selectAll('path.hull').style('display', d => hiddenCommunities.has(d.id) ? 'none' : null) + hullG.selectAll('text.hull-label').style('display', d => hiddenCommunities.has(d.id) ? 
'none' : null) + } + ) + } + + simulation.on('tick', () => { + linkSel + .attr('x1', d => (d.source as CommunityD3Node).x ?? 0) + .attr('y1', d => (d.source as CommunityD3Node).y ?? 0) + .attr('x2', d => { + const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node + const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0) + const dist = Math.sqrt(dx * dx + dy * dy) || 1 + return (t.x ?? 0) - (dx / dist) * (t.symbolSize + 2) + }) + .attr('y2', d => { + const s = d.source as CommunityD3Node, t = d.target as CommunityD3Node + const dx = (t.x ?? 0) - (s.x ?? 0), dy = (t.y ?? 0) - (s.y ?? 0) + const dist = Math.sqrt(dx * dx + dy * dy) || 1 + return (t.y ?? 0) - (dy / dist) * (t.symbolSize + 2) + }) + nodeSel.attr('transform', d => `translate(${d.x ?? 0},${d.y ?? 0})`) + renderHulls(hullG, buildHullData(nodes, communityMap, communityCaption, colors), hiddenCommunities, nodes, simulation, (n) => onCommunityClickRef.current?.(n), communityNodeMap) + }) + + return () => { simulation.stop(); d3.select(container).selectAll('svg').remove() } +} + +// ─── Legend ─────────────────────────────────────────────────────────────────── + +export interface LegendItem { + key: string | number + label: string + color: string +} + +const LEGEND_GAP = 12 +const LEGEND_RECT_W = 20 +const LEGEND_RECT_H = 10 +const LEGEND_TEXT_OFFSET = 24 +const LEGEND_FONT_SIZE = 11 +const LEGEND_ROW_H = 24 +const LEGEND_BOTTOM_PAD = 8 + +// Approximate text width using canvas measureText if available, else char-based estimate +function measureText(text: string, fontSize: number): number { + try { + const ctx = document.createElement('canvas').getContext('2d') + if (ctx) { ctx.font = `${fontSize}px sans-serif`; return ctx.measureText(text).width } + } catch { /* noop */ } + return text.length * fontSize * 0.6 +} + +export function renderLegend( + svg: d3.Selection, + items: LegendItem[], + width: number, + height: number, + onToggle: (key: string | number, hidden: boolean) => void +) { 
+ // Compute per-item width: rect + text-offset + textW + const itemWidths = items.map(item => + LEGEND_RECT_W + LEGEND_TEXT_OFFSET + measureText(item.label, LEGEND_FONT_SIZE) + ) + + // Layout items into rows + const rows: { item: LegendItem; w: number; x: number; row: number }[] = [] + let rowIdx = 0, curX = 0 + itemWidths.forEach((w, i) => { + const slotW = w + LEGEND_GAP + if (curX > 0 && curX + w > width - LEGEND_GAP * 2) { rowIdx++; curX = 0 } + rows.push({ item: items[i], w, x: curX, row: rowIdx }) + curX += slotW + }) + + const totalRows = rowIdx + 1 + const totalH = totalRows * LEGEND_ROW_H + const baseY = height - totalH - LEGEND_BOTTOM_PAD + + // Center each row + const rowWidths: number[] = Array(totalRows).fill(0) + rows.forEach(({ w, row }, i) => { + rowWidths[row] += w + (i > 0 && rows[i - 1].row === row ? LEGEND_GAP : 0) + }) + // Recalculate row widths properly + const rowTotals: number[] = Array(totalRows).fill(0) + const rowCounts: number[] = Array(totalRows).fill(0) + rows.forEach(r => { rowCounts[r.row]++; rowTotals[r.row] += r.w }) + rowTotals.forEach((_, ri) => { rowTotals[ri] += Math.max(0, rowCounts[ri] - 1) * LEGEND_GAP }) + + const legendG = svg.append('g') + + rows.forEach(({ item, x, row }) => { + const rowOffsetX = (width - rowTotals[row]) / 2 + const g = legendG.append('g') + .attr('transform', `translate(${rowOffsetX + x},${baseY + row * LEGEND_ROW_H + LEGEND_ROW_H / 2})`) + .style('cursor', 'pointer') + + const rect = g.append('rect') + .attr('x', 0).attr('y', -LEGEND_RECT_H / 2) + .attr('width', LEGEND_RECT_W).attr('height', LEGEND_RECT_H).attr('rx', 2) + .attr('fill', item.color) + + const text = g.append('text') + .text(item.label) + .attr('x', LEGEND_TEXT_OFFSET).attr('dy', '0.35em') + .attr('font-size', `${LEGEND_FONT_SIZE}px`).attr('fill', '#5B6167') + + let hidden = false + g.on('click', () => { + hidden = !hidden + rect.attr('fill', hidden ? '#ccc' : item.color) + text.attr('fill', hidden ? 
'#bbb' : '#5B6167') + onToggle(item.key, hidden) + }) + }) +} diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 62f404aa..baeb5848 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1482,6 +1482,33 @@ export const en = { memoryNum: 'memories', memory_config_name: 'Memory Engine', searchPlaceholder: 'Search memory store name', + + communityNetwork: 'Community Graph', + community: 'Community', + "Person": "Person Entity Node", + "Organization": "Organization Entity Node", + "ORG": "Organization Entity Node", + "Location": "Location Entity Node", + "LOC": "Location Entity Node", + "Event": "Event Entity Node", + "Concept": "Concept Entity Node", + "Time": "Time Entity Node", + "Position": "Position Entity Node", + "WorkRole": "Work Role Entity Node", + "System": "System Entity Node", + "Policy": "Policy Entity Node", + "HistoricalPeriod": "Historical Period Entity Node", + "HistoricalState": "Historical State Entity Node", + "HistoricalEvent": "Historical Event Entity Node", + "EconomicFactor": "Economic Factor Entity Node", + "Condition": "Condition Entity Node", + "Numeric": "Numeric Entity Node", + "Work": "Work / Output", + member_count: 'Member Count', + member_count_desc: 'entities', + summary: 'Summary', + core_entities: 'Core Entities', + communityDetailEmptyDesc: 'Click on a community in the chart on the left to view details', }, space: { createSpace: 'Create Space', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 387c67c3..e8aa7e5a 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1480,6 +1480,33 @@ export const zh = { memoryNum: '条记忆', memory_config_name: '记忆引擎', searchPlaceholder: '搜索记忆库名称', + + communityNetwork: '社区图谱', + community: '社区', + "Person": "人物实体节点", + "Organization": "组织实体节点", + "ORG": "组织实体节点", + "Location": "地点实体节点", + "LOC": "地点实体节点", + "Event": "事件实体节点", + "Concept": "概念实体节点", + "Time": "时间实体节点", + "Position": "职位实体节点", + "WorkRole": "职业实体节点", + "System": "系统实体节点", + "Policy": "政策实体节点", + 
"HistoricalPeriod": "历史时期实体节点", + "HistoricalState": "历史国家实体节点", + "HistoricalEvent": "历史事件实体节点", + "EconomicFactor": "经济因素实体节点", + "Condition": "条件实体节点", + "Numeric": "数值实体节点", + "Work": "作品/工作成果", + member_count: '成员数', + member_count_desc: '个实体', + summary: '摘要', + core_entities: '核心实体', + communityDetailEmptyDesc: '点击左侧图表中的社区查看详情', }, space: { createSpace: '创建空间', diff --git a/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx new file mode 100644 index 00000000..2757498d --- /dev/null +++ b/web/src/views/UserMemoryDetail/components/CommunityNetwork.tsx @@ -0,0 +1,72 @@ +import React, { useState, type FC, useEffect } from 'react' +import { useParams } from 'react-router-dom' +import { useTranslation } from 'react-i18next' +import type { CommunityD3Node, CommunityGraphData, RawCommunityGraphData, RawCommunityNode } from '@/components/D3Graph/types' +import { buildCommunityGraphData } from '@/components/D3Graph/utils' +import CommunityGraph from '@/components/D3Graph/CommunityGraph' +import { getMemoryCommunityGraph } from '@/api/memory' + +// ─── Tooltip ────────────────────────────────────────────────────────────────── + +const NodeTooltip: FC<{ node: CommunityD3Node }> = ({ node }) => { + const { t } = useTranslation() + return ( +
+
+ {node.properties?.name ?? node.name} +
+ {node.properties?.description && ( +
+ {node.properties.description} +
+ )} +
+ {t('userMemory.type')}: + {t(`userMemory.${node.properties?.entity_type}`)} +
+
+ {t('userMemory.community')}: + {node.properties?.community_name} +
+
+ ) +} + +// ─── Component ──────────────────────────────────────────────────────────────── + +const CommunityNetwork: FC<{ onSelectCommunity?: (node: RawCommunityNode) => void }> = ({ onSelectCommunity }) => { + const { id } = useParams() + const [graphData, setGraphData] = useState(null) + const [empty, setEmpty] = useState(false) + + useEffect(() => { + if (!id) return + const controller = new AbortController() + setEmpty(false) + setGraphData(null) + getMemoryCommunityGraph(id, { signal: controller.signal }).then(res => { + const raw = res as RawCommunityGraphData + if (!raw.nodes?.length) { setEmpty(true); return } + const built = buildCommunityGraphData(raw) + if (!built) { setEmpty(true); return } + setGraphData(built) + }).catch((e) => { if (e?.code !== 'ERR_CANCELED') setEmpty(true) }) + return () => controller.abort() + }, [id]) + + return ( + } + /> + ) +} + +export default React.memo(CommunityNetwork) diff --git a/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx b/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx index aa8b9d7c..66e37a45 100644 --- a/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx +++ b/web/src/views/UserMemoryDetail/components/RelationshipNetwork.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 18:32:00 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 18:32:00 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-13 14:51:17 */ /** * Relationship Network Component @@ -13,18 +13,20 @@ import React, { type FC, useEffect, useState, useRef, useCallback } from 'react' import { useTranslation } from 'react-i18next' import { useParams, useNavigate } from 'react-router-dom' -import { Col, Row, Space, Button } from 'antd' +import { Col, Row, Space, Button, Tabs, Flex, Divider } from 'antd' import dayjs from 'dayjs' import ReactEcharts from 'echarts-for-react' import RbCard from '@/components/RbCard/Card' import detailEmpty from 
'@/assets/images/userMemory/detail_empty.png' import type { Node, Edge, GraphData, StatementNodeProperties, ExtractedEntityNodeProperties } from '../types' +import type { RawCommunityNode } from '@/components/D3Graph/types' import { getMemorySearchEdges, } from '@/api/memory' import Empty from '@/components/Empty' import Tag from '@/components/Tag' +import CommunityNetwork from './CommunityNetwork' /** Node color palette */ const colors = ['#155EEF', '#369F21', '#4DA8FF', '#FF5D34', '#9C6FFF', '#FF8A4C', '#8BAEF7', '#FFB048'] @@ -36,16 +38,21 @@ const RelationshipNetwork:FC = () => { const [nodes, setNodes] = useState([]) const [links, setLinks] = useState([]) const [categories, setCategories] = useState<{ name: string }[]>([]) - const [selectedNode, setSelectedNode] = useState(null) + const [selectedNode, setSelectedNode] = useState(null) // const [fullScreen, setFullScreen] = useState(false) const navigate = useNavigate() + const [activeTab, setActiveTab] = useState('relationshipNetwork') console.log('categories', categories) + const edgeAbortRef = useRef(null) + /** Fetch relationship network data */ const getEdgeData = useCallback(() => { if (!id) return + edgeAbortRef.current?.abort() + edgeAbortRef.current = new AbortController() setSelectedNode(null) - getMemorySearchEdges(id).then((res) => { + getMemorySearchEdges(id, { signal: edgeAbortRef.current.signal }).then((res) => { const { nodes, edges, statistics } = res as GraphData const curNodes: Node[] = [] const curEdges: Edge[] = [] @@ -123,6 +130,7 @@ const RelationshipNetwork:FC = () => { useEffect(() => { if (!id) return getEdgeData() + return () => { edgeAbortRef.current?.abort() } }, [id]) useEffect(() => { @@ -153,34 +161,36 @@ const RelationshipNetwork:FC = () => { const params = new URLSearchParams({ nodeId: selectedNode.id, nodeLabel: selectedNode.label, - nodeName: selectedNode.name || '' + nodeName: (selectedNode as Node).name || '' }) 
navigate(`/user-memory/detail/${id}/GRAPH?${params.toString()}`) } + const handleChangeTab = (tab: string) => { + if (tab === 'communityNetwork') { + edgeAbortRef.current?.abort() + } else { + getEdgeData() + } + setActiveTab(tab) + setSelectedNode(null) + } return ( {/* Relationship Network */} - - //
- // {t('userMemory.fullScreen')} - //
- // } - > + + ({ key, label: t(`userMemory.${key}`) }))} + activeKey={activeTab} + onChange={handleChangeTab} + />
- {nodes.length === 0 ? ( - - ) : ( - setSelectedNode(community)} /> + : nodes.length === 0 + ? + : { } }} /> - )} + }
{/* Memory Details */} - -
- {t('userMemory.completeMemory')} - } + bodyClassName="rb:p-0!" + extra={selectedNode && !(selectedNode as RawCommunityNode).properties.community_id && ( + + )} >
{!selectedNode - ? - : <> - {selectedNode.name &&
{selectedNode.name}
} -
- <> + ? + : (selectedNode as RawCommunityNode).properties.community_id + ?
+
+ {(selectedNode as RawCommunityNode).properties.name} +
+
{t('userMemory.summary')}
+
+ {(selectedNode as RawCommunityNode).properties.summary} +
+ + {t('userMemory.member_count')} + {(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')} + + + +
{t('userMemory.core_entities')}
+
    + {(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) =>
  • {entity}
  • )} +
+
+ : <> + {(selectedNode as Node).name && ( +
+ {(selectedNode as Node).name} +
+ )} +
{t('userMemory.memoryContent')}
{['Chunk', 'Dialogue', 'MemorySummary'].includes(selectedNode.label) && 'content' in selectedNode.properties ? selectedNode.properties.content : selectedNode.label === 'ExtractedEntity' && 'description' in selectedNode.properties - ? selectedNode.properties.description - : selectedNode.label === 'Statement' && 'statement' in selectedNode.properties - ? selectedNode.properties.statement - : '' - } -
- -
-
{t('userMemory.created_at')}
-
- {dayjs(selectedNode?.properties.created_at).format('YYYY-MM-DD HH:mm:ss')} + ? selectedNode.properties.description + : selectedNode.label === 'Statement' && 'statement' in selectedNode.properties + ? selectedNode.properties.statement + : ''}
- {selectedNode?.properties.associative_memory > 0 &&
-
{t('userMemory.associative_memory')}
+
+
{t('userMemory.created_at')}
- {selectedNode?.properties.associative_memory} {t('userMemory.unix')}{t('userMemory.associative_memory')} + {dayjs((selectedNode as Node).properties.created_at).format('YYYY-MM-DD HH:mm:ss')}
-
} - {selectedNode.label === 'Statement' && <> - {(['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => { - const statementProps = selectedNode.properties as StatementNodeProperties; - if ((key === 'emotion_keywords' && statementProps[key]?.length > 0) || typeof statementProps[key] === 'string') { - console.log('statementProps[key]', statementProps[key]) - return ( -
- {t(`userMemory.Statement_${key}`)} -
- {key === 'emotion_keywords' - ? {statementProps.emotion_keywords.map((vo, index) => {vo})} - : statementProps[key] - } + {(selectedNode as Node).properties.associative_memory > 0 && ( +
+
{t('userMemory.associative_memory')}
+
+ {(selectedNode as Node).properties.associative_memory} + {' '}{t('userMemory.unix')}{t('userMemory.associative_memory')} +
+
+ )} + + {selectedNode.label === 'Statement' && ( + (['emotion_keywords', 'emotion_type', 'emotion_subject', 'importance_score'] as const).map(key => { + const p = selectedNode.properties as StatementNodeProperties + if ((key === 'emotion_keywords' && p[key]?.length > 0) || typeof p[key] === 'string') { + return ( +
+ {t(`userMemory.Statement_${key}`)} +
+ {key === 'emotion_keywords' + ? {p.emotion_keywords.map((v, i) => {v})} + : p[key]} +
-
- ) - } - return null - })} - } - {selectedNode.label === 'ExtractedEntity' && <> - {(['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => { - const entityProps = selectedNode.properties as ExtractedEntityNodeProperties; - if (entityProps[key]) { - return ( -
- {t(`userMemory.ExtractedEntity_${key}`)} -
- {Array.isArray(entityProps[key]) && entityProps[key].length > 0 - ? entityProps[key].map((vo, index) =>
- {vo}
) - : entityProps[key] - } + ) + } + return null + }) + )} + + {selectedNode.label === 'ExtractedEntity' && ( + (['name', 'entity_type', 'aliases', 'connect_strngth', 'importance_score'] as const).map(key => { + const p = selectedNode.properties as ExtractedEntityNodeProperties + if (p[key]) { + return ( +
+ {t(`userMemory.ExtractedEntity_${key}`)} +
+ {Array.isArray(p[key]) && p[key].length > 0 + ? p[key].map((v, i) =>
- {v}
) + : p[key]} +
-
- ) - } - return null - })} - } + ) + } + return null + }) + )} +
-
- + }
diff --git a/web/src/views/UserMemoryDetail/types.ts b/web/src/views/UserMemoryDetail/types.ts index 8333cb2c..72e896ad 100644 --- a/web/src/views/UserMemoryDetail/types.ts +++ b/web/src/views/UserMemoryDetail/types.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 17:57:15 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 17:57:15 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-13 11:49:52 */ /** * User Memory Detail Types @@ -90,6 +90,7 @@ export interface ExtractedEntityNodeProperties { connect_strngth: string; importance_score: number; associative_memory: number; + community_name?: string; } /** * Memory summary node @@ -246,4 +247,53 @@ export interface ForgetData { */ export interface GraphDetailRef { handleOpen: (vo: Node) => void -} \ No newline at end of file +} +// Community +export type CommunityNodeType = 'Community' | 'ExtractedEntity'; +export type CommunityEdgeType = 'BELONGS_TO_COMMUNITY' | 'EXTRACTED_RELATIONSHIP'; +export type CommunityEntityType = "Person" | "Organization" | "ORG" | "Location" | "LOC" | "Event" | "Concept" | "Time" | "Position" | "WorkRole" | "System" | "Policy" | "HistoricalPeriod" | "HistoricalState" | "HistoricalEvent" | "EconomicFactor" | "Condition" | "Numeric" | "Work"; +// 社区节点 +export interface CommunityTypeNode { + id: string; + label: 'Community'; + properties: { + community_id: string; + end_user_id: string; + member_count: number; + updated_at: string; + name: string; + summary: string; + core_entities: string[]; + member_entity_ids: string[]; + }; +} +// 核心实体 +export interface ExtractedEntityTypeNode { + id: string; + label: 'ExtractedEntity'; + properties: { + name: string; + end_user_id: string; + description: string; + created_at: string; + entity_type: CommunityEntityType; + community_name: string; + }; +} +// 社区图谱连线 +export interface CommunityEdge { + id: string; + target: string; + source: string; +} +export interface CommunityStatistics { + total_nodes: number; + 
total_edges: number; + node_types: Record; + edge_types: Record; +} +export interface CommunityGraphData { + nodes: (CommunityTypeNode | ExtractedEntityTypeNode)[]; + edges: CommunityEdge[]; + statistics: CommunityStatistics; +}