Merge branch 'develop' into release/v0.2.7
This commit is contained in:
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||
# Database connection URL - DO NOT hardcode credentials here!
|
||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||
# Example: postgresql://user:password@localhost:5432/dbname
|
||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
|
||||
@@ -70,43 +70,44 @@ celery_app.conf.update(
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# Flower settings
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
@@ -115,6 +116,7 @@ celery_app.conf.update(
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -131,7 +133,7 @@ implicit_emotions_update_schedule = crontab(
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@@ -16,6 +16,7 @@ from . import (
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
@@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -53,6 +53,7 @@ def list_apps(
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
@@ -84,6 +85,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
@@ -93,6 +95,37 @@ def list_apps(
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
@cur_workspace_access_guard()
def list_my_shared_out(
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """List the share records the current workspace has handed out ("my shares").

    The workspace is taken from ``current_user.current_workspace_id``; the
    lookup itself is delegated to ``AppService.list_my_shared_out``.

    Returns:
        A ``success`` response whose data is a list of ``AppShare`` schemas.
    """
    workspace_id = current_user.current_workspace_id
    outgoing = app_service.AppService(db).list_my_shared_out(workspace_id=workspace_id)

    # Convert every record returned by the service into the response schema.
    return success(data=list(map(app_schema.AppShare.model_validate, outgoing)))
|
||||
|
||||
|
||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
@cur_workspace_access_guard()
def unshare_all_apps_to_workspace(
    target_workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Revoke every app share from the caller's current workspace to one target workspace.

    Args:
        target_workspace_id: Workspace whose incoming shares are cancelled.

    Returns:
        A ``success`` response carrying the count reported by the service,
        both in the message and as ``{"count": ...}``.
    """
    workspace_id = current_user.current_workspace_id
    count = app_service.AppService(db).unshare_all_apps_to_workspace(
        target_workspace_id=target_workspace_id,
        workspace_id=workspace_id,
    )
    payload = {"count": count}
    return success(msg=f"已取消 {count} 个应用的分享", data=payload)
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
@@ -302,7 +335,8 @@ def share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
permission=payload.permission
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
@@ -333,6 +367,32 @@ def unshare_app(
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
@cur_workspace_access_guard()
def update_share_permission(
    app_id: uuid.UUID,
    target_workspace_id: uuid.UUID,
    payload: app_schema.UpdateSharePermissionRequest,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Change the permission of an existing share (readonly <-> editable).

    The original note states that only shares of apps belonging to the
    caller's own workspace may be modified; this handler does not check that
    itself and simply forwards the current workspace id to the service.

    Args:
        app_id: App whose share is being updated.
        target_workspace_id: Workspace the app is shared with.
        payload: Request body; only ``payload.permission`` is read here.

    Returns:
        A ``success`` response containing the updated ``AppShare`` schema.
    """
    workspace_id = current_user.current_workspace_id

    updated_share = app_service.AppService(db).update_share_permission(
        app_id=app_id,
        target_workspace_id=target_workspace_id,
        permission=payload.permission,
        workspace_id=workspace_id,
    )

    body = app_schema.AppShare.model_validate(updated_share)
    return success(data=body)
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
@@ -356,6 +416,46 @@ def list_app_shares(
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
@cur_workspace_access_guard()
def remove_all_shared_apps_from_workspace(
    source_workspace_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Drop every app shared into the current workspace by one source workspace.

    This is the recipient-side operation: the caller's current workspace is
    passed to the service as ``workspace_id``.

    Args:
        source_workspace_id: Workspace whose shares should be removed.

    Returns:
        A ``success`` response carrying the count reported by the service,
        both in the message and as ``{"count": ...}``.
    """
    workspace_id = current_user.current_workspace_id
    count = app_service.AppService(db).remove_all_shared_apps_from_workspace(
        source_workspace_id=source_workspace_id,
        workspace_id=workspace_id,
    )
    payload = {"count": count}
    return success(msg=f"已移除 {count} 个共享应用", data=payload)
|
||||
|
||||
|
||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
@cur_workspace_access_guard()
def remove_shared_app(
    app_id: uuid.UUID,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Let a recipient remove a shared app from their own workspace.

    Per the original note, the source app is left intact (only the share
    record goes away) and only apps shared to the caller's workspace can be
    removed; both rules live in the service, not in this handler.

    Args:
        app_id: Shared app to remove from the current workspace.

    Returns:
        A ``success`` response with a confirmation message.
    """
    workspace_id = current_user.current_workspace_id

    # The service result is intentionally not used; the handler only confirms.
    app_service.AppService(db).remove_shared_app(
        app_id=app_id,
        workspace_id=workspace_id,
    )

    return success(msg="已移除共享应用")
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
@@ -744,6 +844,15 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -789,6 +898,8 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -800,7 +911,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
@@ -831,7 +942,7 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,10 +43,10 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
@@ -69,7 +72,7 @@ async def login_for_access_token(
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -82,7 +85,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -110,14 +113,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -125,18 +129,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -167,7 +171,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -175,14 +179,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -192,5 +197,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
def get_languages(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Get list of all supported languages.

    One entry is produced for every locale reported by the translation
    service. A locale is flagged as enabled when it appears in
    ``settings.I18N_SUPPORTED_LANGUAGES`` and as default when it equals
    ``settings.I18N_DEFAULT_LANGUAGE``.

    Args:
        t: Translator callable used for the response message.
        current_user: Authenticated user (only used for logging here).

    Returns:
        List of language information including code, name, and status
    """
    api_logger.info(f"Get languages request from user: {current_user.username}")

    from app.core.config import settings
    translation_service = get_translation_service()

    # Native display names; built once instead of once per locale.
    # Locales missing from this table fall back to their own code.
    native_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }

    languages = [
        LanguageInfo(
            code=locale,
            name=locale.upper(),
            native_name=native_names.get(locale, locale),
            is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
            is_default=locale == settings.I18N_DEFAULT_LANGUAGE,
        )
        for locale in translation_service.get_available_locales()
    ]

    response = LanguageListResponse(languages=languages)

    api_logger.info(f"Returning {len(languages)} languages")
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    return success(data=response.model_dump(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
def get_language(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Get information about a specific language.

    Args:
        locale: Language code (e.g., 'zh', 'en')
        t: Translator callable used for response and error messages.
        current_user: Authenticated user (only used for logging here).

    Returns:
        Language information

    Raises:
        HTTPException: 404 when ``locale`` is not among the locales reported
            by the translation service.
    """
    api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")

    from app.core.config import settings
    translation_service = get_translation_service()

    # Check if locale exists
    if locale not in translation_service.get_available_locales():
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Native display names; unknown locales fall back to their own code.
    native_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }

    language_info = LanguageInfo(
        code=locale,
        name=locale.upper(),
        native_name=native_names.get(locale, locale),
        is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
        is_default=locale == settings.I18N_DEFAULT_LANGUAGE,
    )

    api_logger.info(f"Returning language info for: {locale}")
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    return success(data=language_info.model_dump(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
def add_language(
    request: LanguageCreateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to add a new language (superuser only).

    Nothing is created by this endpoint: it only rejects codes that already
    exist and otherwise answers with instructions, because translation files
    have to be placed in the locales directory by hand.

    Args:
        request: Language creation request; only ``request.code`` is read.
        t: Translator callable used for response and error messages.
        current_user: Authenticated superuser (only used for logging here).

    Returns:
        Success response whose message contains the manual instructions.

    Raises:
        HTTPException: 400 when the language code already exists.
    """
    api_logger.info(
        f"Add language request: code={request.code}, admin={current_user.username}"
    )

    from app.core.config import settings
    translation_service = get_translation_service()

    # Reject codes the translation service already knows about.
    known_locales = translation_service.get_available_locales()
    already_present = request.code in known_locales
    if already_present:
        api_logger.warning(f"Language already exists: {request.code}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=t("i18n.language.already_exists", locale=request.code)
        )

    # Validation-only endpoint: the files themselves are created manually.
    api_logger.info(
        f"Language addition validated: {request.code}. "
        "Translation files need to be created manually."
    )

    instructions = t(
        "i18n.language.add_instructions",
        locale=request.code,
        dir=settings.I18N_CORE_LOCALES_DIR
    )
    return success(msg=instructions)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
def update_language(
    locale: str,
    request: LanguageUpdateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to update a language's configuration (superuser only).

    No configuration is changed here: the endpoint checks that the locale
    exists and answers with instructions, since real changes require editing
    environment variables or config files.

    Args:
        locale: Language code
        request: Language update request (parsed but not otherwise read here)
        t: Translator callable used for response and error messages.
        current_user: Authenticated superuser (only used for logging here).

    Returns:
        Success response whose message contains the manual instructions.

    Raises:
        HTTPException: 404 when the locale is unknown.
    """
    api_logger.info(
        f"Update language request: locale={locale}, admin={current_user.username}"
    )

    translation_service = get_translation_service()

    # Unknown locales cannot be updated.
    known_locales = translation_service.get_available_locales()
    if not (locale in known_locales):
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Validation-only endpoint: settings are changed outside the API.
    api_logger.info(
        f"Language update validated: {locale}. "
        "Configuration changes require environment variable updates."
    )

    instructions = t("i18n.language.update_instructions", locale=locale)
    return success(msg=instructions)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
def get_all_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Get all translations for all or specific locale.

    Args:
        locale: Optional locale filter; when omitted (or empty) every cached
            locale is returned.
        t: Translator callable used for response and error messages.
        current_user: Authenticated user (only used for logging here).

    Returns:
        All translations organized by locale and namespace

    Raises:
        HTTPException: 404 when a locale filter is given but unknown.
    """
    api_logger.info(
        f"Get all translations request: locale={locale}, user={current_user.username}"
    )

    translation_service = get_translation_service()

    # NOTE(review): this reads the service's private ``_cache`` attribute
    # directly; a public accessor, if one exists, would be preferable.
    if locale:
        # Get translations for specific locale
        if locale not in translation_service.get_available_locales():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=t("i18n.language.not_found", locale=locale)
            )

        translations = {
            locale: translation_service._cache.get(locale, {})
        }
    else:
        # Get all translations
        translations = translation_service._cache

    response = TranslationResponse(translations=translations)

    api_logger.info(f"Returning translations for: {locale or 'all locales'}")
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    return success(data=response.model_dump(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
# NOTE(review): this route is registered before "/translations/missing" further
# down in this module; because routes are matched in registration order, a GET
# to /translations/missing appears to land here with locale="missing" - confirm
# and reorder the routes if so.
@router.get("/translations/{locale}", response_model=ApiResponse)
def get_locale_translations(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return every cached translation for one locale.

    Args:
        locale: Language code
        t: Translator callable used for response and error messages.
        current_user: Authenticated user (only used for logging here).

    Returns:
        Success response with ``{"locale": ..., "translations": ...}``, where
        translations are whatever the service cache holds for the locale.

    Raises:
        HTTPException: 404 when the locale is unknown.
    """
    api_logger.info(
        f"Get locale translations request: locale={locale}, user={current_user.username}"
    )

    translation_service = get_translation_service()

    # Guard: the locale must be one the service reports as available.
    known_locales = translation_service.get_available_locales()
    if not (locale in known_locales):
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # An available locale with no cache entry yields an empty mapping.
    translations = translation_service._cache.get(locale, {})

    api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
    payload = {"locale": locale, "translations": translations}
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
def get_namespace_translations(
    locale: str,
    namespace: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return the translations of a single namespace within a locale.

    Args:
        locale: Language code
        namespace: Translation namespace (e.g., 'common', 'auth')
        t: Translator callable used for response and error messages.
        current_user: Authenticated user (only used for logging here).

    Returns:
        Success response with the locale, the namespace and its translations.

    Raises:
        HTTPException: 404 when the locale is unknown, or when the namespace
            is missing or empty for that locale.
    """
    api_logger.info(
        f"Get namespace translations request: locale={locale}, "
        f"namespace={namespace}, user={current_user.username}"
    )

    translation_service = get_translation_service()

    # Guard: the locale must be one the service reports as available.
    known_locales = translation_service.get_available_locales()
    if not (locale in known_locales):
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Missing locale or namespace entries both collapse to an empty mapping,
    # which is reported as "namespace not found" below.
    namespace_translations = translation_service._cache.get(locale, {}).get(namespace, {})

    if not namespace_translations:
        api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
        )

    api_logger.info(
        f"Returning translations for namespace: {namespace} in locale: {locale}"
    )
    payload = {
        "locale": locale,
        "namespace": namespace,
        "translations": namespace_translations
    }
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
def update_translation(
    locale: str,
    key: str,
    request: TranslationUpdateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to update a single translation (superuser only).

    No translation is written here: the endpoint checks the locale and the
    key format and answers with instructions, since real updates require
    editing the translation files in the locales directory.

    Args:
        locale: Language code
        key: Translation key (format: "namespace.key.subkey")
        request: Translation update request (parsed but not otherwise read here)
        t: Translator callable used for response and error messages.
        current_user: Authenticated superuser (only used for logging here).

    Returns:
        Success response whose message contains the manual instructions.

    Raises:
        HTTPException: 404 when the locale is unknown; 400 when ``key``
            contains no "." separator.
    """
    api_logger.info(
        f"Update translation request: locale={locale}, key={key}, "
        f"admin={current_user.username}"
    )

    translation_service = get_translation_service()

    # Guard: the locale must be one the service reports as available.
    known_locales = translation_service.get_available_locales()
    if not (locale in known_locales):
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Guard: a key needs at least one "." to separate namespace from name.
    has_separator = "." in key
    if not has_separator:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=t("i18n.translation.invalid_key_format", key=key)
        )

    # Validation-only endpoint: the JSON files are edited manually.
    api_logger.info(
        f"Translation update validated: {locale}/{key}. "
        "Translation files need to be updated manually."
    )

    instructions = t("i18n.translation.update_instructions", locale=locale, key=key)
    return success(msg=instructions)
|
||||
|
||||
|
||||
def _collect_translation_keys(namespaces):
    """
    Flatten one locale's translations into a set of dotted leaf keys.

    Args:
        namespaces: Mapping of namespace name to a (possibly nested) dict of
            translations.

    Returns:
        Set of keys such as "namespace.key.subkey", one for every value that
        is not itself a dict.
    """

    def walk(node, prefix):
        # Depth-first walk; dict values are descended into, anything else is a leaf.
        for name, value in node.items():
            dotted = f"{prefix}.{name}" if prefix else name
            if isinstance(value, dict):
                yield from walk(value, dotted)
            else:
                yield dotted

    keys = set()
    for namespace, translations in namespaces.items():
        keys.update(walk(translations, namespace))
    return keys


@router.get("/translations/missing", response_model=ApiResponse)
def get_missing_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Get list of missing translations.

    Uses the keys of the default locale (``settings.I18N_DEFAULT_LANGUAGE``)
    as the reference set and reports, per target locale, the keys that the
    target does not define.

    Args:
        locale: Optional locale to check. When omitted, every available
            locale except the default one is checked. A locale that is not
            in the available list is skipped without an error.
        t: Translator callable used to build the response message.
        current_user: Authenticated user making the request.

    Returns:
        Success response whose data is a ``MissingTranslationsResponse`` dict;
        ``missing_translations`` maps each locale that has gaps to its sorted
        list of missing keys. Locales with no missing keys are omitted.
    """
    api_logger.info(
        f"Get missing translations request: locale={locale}, user={current_user.username}"
    )

    from app.core.config import settings
    translation_service = get_translation_service()

    default_locale = settings.I18N_DEFAULT_LANGUAGE
    available_locales = translation_service.get_available_locales()

    # NOTE(review): this reads the service's private ``_cache`` attribute
    # directly; presumably it maps locale -> {namespace: translations}.
    default_keys = _collect_translation_keys(
        translation_service._cache.get(default_locale, {})
    )

    target_locales = [locale] if locale else [
        loc for loc in available_locales if loc != default_locale
    ]

    missing_by_locale = {}
    for target_locale in target_locales:
        # An explicitly requested but unknown locale is ignored, not rejected.
        if target_locale not in available_locales:
            continue

        target_keys = _collect_translation_keys(
            translation_service._cache.get(target_locale, {})
        )
        missing_keys = default_keys - target_keys
        if missing_keys:
            missing_by_locale[target_locale] = sorted(missing_keys)

    response = MissingTranslationsResponse(missing_translations=missing_by_locale)

    total_missing = sum(len(keys) for keys in missing_by_locale.values())
    api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")

    return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
def reload_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Trigger hot reload of translation files (admin only).

    Args:
        locale: Optional locale to reload (defaults to reloading all locales).
        t: Translator callable used to build response messages.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data is a ``ReloadResponse`` dict with the
        reloaded locales and the total number of available locales.

    Raises:
        HTTPException: 403 when ``settings.I18N_ENABLE_HOT_RELOAD`` is falsy;
            500 when the reload or the response construction fails.
    """
    api_logger.info(
        f"Reload translations request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    from app.core.config import settings

    if not settings.I18N_ENABLE_HOT_RELOAD:
        api_logger.warning("Hot reload is disabled in configuration")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=t("i18n.reload.disabled")
        )

    translation_service = get_translation_service()

    try:
        # Reload translations
        translation_service.reload(locale)

        # Get statistics. When a single locale was requested it is reported
        # as-is; it is not checked against the available list here.
        available_locales = translation_service.get_available_locales()
        reloaded_locales = [locale] if locale else available_locales

        response = ReloadResponse(
            success=True,
            reloaded_locales=reloaded_locales,
            total_locales=len(available_locales)
        )

        api_logger.info(
            f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
        )

        return success(data=response.dict(), msg=t("i18n.reload.success"))

    except Exception as e:
        api_logger.error(f"Failed to reload translations: {str(e)}")
        # Chain the original exception so the root cause stays attached to
        # the HTTP error in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t("i18n.reload.failed", error=str(e))
        ) from e
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
def get_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Return the i18n performance metrics summary (admin only).

    Args:
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data is whatever the translation service's
        ``get_metrics_summary()`` returns.
    """
    api_logger.info(f"Get metrics request: admin={current_user.username}")

    summary = get_translation_service().get_metrics_summary()

    api_logger.info("Returning i18n metrics")
    retrieved_msg = t("common.success.retrieved")
    return success(data=summary, msg=retrieved_msg)
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
def get_cache_stats(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Return translation cache statistics and memory usage (admin only).

    Args:
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data has two entries: ``cache`` (the service's
        ``get_cache_stats()`` result) and ``memory`` (its ``get_memory_usage()``
        result).
    """
    api_logger.info(f"Get cache stats request: admin={current_user.username}")

    service = get_translation_service()
    # The dict literal evaluates top to bottom: cache stats are read before
    # memory usage.
    payload = {
        "cache": service.get_cache_stats(),
        "memory": service.get_memory_usage(),
    }

    api_logger.info("Returning cache statistics")
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
def get_prometheus_metrics(
    current_user: User = Depends(get_current_superuser),
):
    """
    Return the i18n metrics rendered by ``export_prometheus()`` (admin only).

    Args:
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        ``PlainTextResponse`` carrying the exported metrics text.
    """
    api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")

    # Imported under an alias because this module defines its own
    # ``get_metrics`` route handler.
    from app.i18n.metrics import get_metrics as load_metrics
    from fastapi.responses import PlainTextResponse

    exported_text = load_metrics().export_prometheus()
    return PlainTextResponse(content=exported_text)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
def reset_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Reset the i18n metrics and the translation cache statistics (admin only).

    Args:
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response with the localized reset confirmation message.
    """
    api_logger.info(f"Reset metrics request: admin={current_user.username}")

    # Imported under an alias because this module defines its own
    # ``get_metrics`` route handler.
    from app.i18n.metrics import get_metrics as load_metrics

    # Metrics are reset first, then the cache counters.
    load_metrics().reset()
    get_translation_service().cache.reset_stats()

    api_logger.info("Metrics reset completed")
    return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
def get_missing_translation_logs(
    locale: Optional[str] = None,
    limit: Optional[int] = 100,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Return the logged missing translations with context (admin only).

    Args:
        locale: Optional locale filter passed to the translation logger.
        limit: Maximum number of context entries requested (default: 100).
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data has three entries:
        ``missing_translations``, ``recent_context`` and ``statistics``.
    """
    api_logger.info(
        f"Get missing translation logs request: locale={locale}, "
        f"limit={limit}, admin={current_user.username}"
    )

    tracker = get_translation_service().translation_logger

    # Queried in this order: missing keys, recent context, then statistics.
    payload = {
        "missing_translations": tracker.get_missing_translations(locale),
        "recent_context": tracker.get_missing_with_context(locale, limit),
    }
    statistics = tracker.get_statistics()
    payload["statistics"] = statistics

    api_logger.info(
        f"Returning {statistics['total_missing']} missing translations"
    )
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
def generate_missing_translation_report(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Build and return the missing-translation report (admin only).

    Args:
        locale: Optional locale filter passed to ``generate_report``.
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data is the report produced by the translation
        logger; the report is expected to expose a ``total_missing`` entry,
        which is logged here.
    """
    api_logger.info(
        f"Generate missing translation report request: locale={locale}, "
        f"admin={current_user.username}"
    )

    report = get_translation_service().translation_logger.generate_report(locale)

    api_logger.info(
        f"Generated report with {report['total_missing']} missing translations"
    )
    retrieved_msg = t("common.success.retrieved")
    return success(data=report, msg=retrieved_msg)
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
def export_missing_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Export missing translations to JSON file (admin only).

    The file is written under ``logs/i18n/`` with a name built from the
    locale (or ``all``) and a ``YYYYmmdd_HHMMSS`` timestamp.

    Args:
        locale: Optional locale used in the output filename. It is only part
            of the name: ``export_to_json`` is called with the path alone.
            Characters other than letters, digits, ``_`` and ``-`` are
            replaced with ``_`` before use.
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response whose data contains ``file_path``, the path handed
        to the exporter.
    """
    api_logger.info(
        f"Export missing translations request: locale={locale}, "
        f"admin={current_user.username}"
    )

    import re
    from datetime import datetime
    translation_service = get_translation_service()
    translation_logger = translation_service.translation_logger

    # Generate filename with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if locale:
        # ``locale`` is a caller-supplied query parameter that becomes part of
        # a file path. Neutralise path separators and dots so the export
        # cannot be redirected outside logs/i18n/.
        safe_locale = re.sub(r"[^A-Za-z0-9_-]", "_", locale)
        locale_suffix = f"_{safe_locale}"
    else:
        locale_suffix = "_all"
    output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"

    # Export to file
    translation_logger.export_to_json(output_file)

    api_logger.info(f"Missing translations exported to: {output_file}")
    return success(
        data={"file_path": output_file},
        msg=t("i18n.logs.export_success", file=output_file)
    )
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
def clear_missing_translation_logs(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser),
):
    """
    Clear the missing-translation logs (admin only).

    Args:
        locale: Optional locale passed to the logger's ``clear``; ``None``
            is passed through when no locale is given.
        t: Translator callable used to build the response message.
        current_user: User resolved by the ``get_current_superuser`` dependency.

    Returns:
        Success response with the localized confirmation message.
    """
    api_logger.info(
        f"Clear missing translation logs request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    tracker = get_translation_service().translation_logger
    tracker.clear(locale)

    api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
    cleared_msg = t("i18n.logs.clear_success")
    return success(msg=cleared_msg)
|
||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
from app.services import mcp_market_config_service, mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -123,6 +123,17 @@ async def get_mcp_servers(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
# 5. Update mcp_market.mcp_count
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
db_mcp_market.mcp_count = total
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@@ -265,6 +276,30 @@ async def create_mcp_market_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
# 2. verify token
|
||||
create_data.status = 1
|
||||
try:
|
||||
api = MCPApi()
|
||||
token = create_data.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': 1,
|
||||
'page_size': 20,
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
create_data.status = 0
|
||||
# 3. create mcp_market_config
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
@@ -395,7 +430,7 @@ async def update_mcp_market_config(
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
# 5. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@@ -193,7 +193,16 @@ async def get_workspace_end_users(
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
||||
try:
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
||||
)
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
# 跳过没有配置的应用
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
@@ -204,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -42,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -50,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -57,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -74,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -105,7 +111,7 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -113,6 +119,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -125,7 +132,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +141,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -144,7 +152,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -152,6 +160,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -164,7 +173,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -172,6 +181,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -186,16 +196,17 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
@@ -203,8 +214,8 @@ def verify_pwd(
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
@@ -212,6 +223,7 @@ async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
@@ -219,7 +231,7 @@ async def send_email_code(
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
@@ -227,6 +239,7 @@ async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
@@ -239,4 +252,51 @@ async def change_email(
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
def get_current_user_language(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    t: Callable = Depends(get_translator)
):
    """Return the language preference of the authenticated user.

    Args:
        db: Database session (injected by dependency).
        current_user: Authenticated user making the request (injected by dependency).
        t: Translator callable used to localize the response message.

    Returns:
        The ``success(...)`` response whose data is a
        ``LanguagePreferenceResponse`` built from the value returned by
        ``user_service.get_user_language_preference``.
    """
    api_logger.info(f"获取用户语言偏好: {current_user.username}")

    # The lookup itself is delegated to the service layer.
    lookup_kwargs = {
        "db": db,
        "user_id": current_user.id,
        "current_user": current_user,
    }
    language = user_service.get_user_language_preference(**lookup_kwargs)

    api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")

    payload = user_schema.LanguagePreferenceResponse(language=language)
    message = t("users.language.get_success")
    return success(data=payload, msg=message)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
def update_current_user_language(
    request: user_schema.LanguagePreferenceRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    t: Callable = Depends(get_translator)
):
    """Set the language preference of the authenticated user.

    Args:
        request: Request body carrying the new ``language`` value.
        db: Database session (injected by dependency).
        current_user: Authenticated user making the request (injected by dependency).
        t: Translator callable used to localize the response message.

    Returns:
        The ``success(...)`` response whose data is a
        ``LanguagePreferenceResponse`` built from the ``preferred_language``
        of the object returned by the service call.
    """
    api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")

    # The update itself is delegated to the service layer.
    update_kwargs = {
        "db": db,
        "user_id": current_user.id,
        "language": request.language,
        "current_user": current_user,
    }
    updated_user = user_service.update_user_language_preference(**update_kwargs)

    api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")

    # Report the value held by the returned user object, not the requested one.
    payload = user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language)
    message = t("users.language.update_success")
    return success(data=payload, msg=message)
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
@@ -295,6 +296,42 @@ async def get_graph_data_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api(
    end_user_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return community graph data for one end user.

    Args:
        end_user_id: Identifier of the end user whose community graph is queried.
        current_user: Authenticated user making the request (injected by dependency).
        db: Database session (injected by dependency).

    Returns:
        A ``fail(...)`` response when the caller has no current workspace or
        when the query raises; otherwise a ``success(...)`` response carrying
        the result of ``analytics_community_graph_data``. If that result has a
        ``"message"`` key and reports zero nodes, the message is used as the
        response text.
    """
    workspace_id = current_user.current_workspace_id

    # Guard: the caller must have selected a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
        f"workspace={workspace_id}"
    )
    api_logger.info(request_summary)

    try:
        result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)

        # Empty graph accompanied by an explanatory message from the service.
        is_empty_with_message = "message" in result and result["statistics"]["total_nodes"] == 0
        if is_empty_with_message:
            api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
            return success(data=result, msg=result.get("message", "查询成功"))

        stats = result["statistics"]
        api_logger.info(
            f"成功获取社区图谱: end_user_id={end_user_id}, "
            f"nodes={stats['total_nodes']}, "
            f"edges={stats['total_edges']}"
        )
        return success(data=result, msg="查询成功")

    except Exception as e:
        # Any failure, including a malformed result, becomes a fail response.
        detail = str(e)
        api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={detail}")
        return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", detail)
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,8 +96,13 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@@ -98,6 +111,8 @@ def create_workspace(
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -118,8 +133,13 @@ def create_workspace(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -127,6 +147,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -139,14 +161,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -168,6 +203,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -178,7 +214,7 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@@ -187,6 +223,7 @@ def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
@@ -198,7 +235,7 @@ def delete_workspace_member(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -208,6 +245,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -220,7 +259,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -232,6 +276,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -246,18 +292,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
    token: str,
    db: Session = Depends(get_db),
    language: str = Depends(get_current_language),
    t: callable = Depends(get_translator)
):
    """Validate a workspace invite token and return the invite info (no authentication required).

    Args:
        token: Invite token taken from the URL path.
        db: Database session (injected by dependency).
        language: Language resolved for the current request (injected by dependency).
        t: Translator callable used to localize the response message.

    Returns:
        The ``success(...)`` response whose data is the result of
        ``workspace_service.validate_invite_token`` after it has been passed
        through ``WorkspaceInviteSerializer.serialize`` for ``language``.
    """
    result = workspace_service.validate_invite_token(db=db, token=token)
    # NOTE(review): the raw invite token is written to the log here; confirm
    # that this is acceptable for the deployment's log retention policy.
    api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")

    # The stale early return with the hard-coded message was removed here:
    # it made the serialization and translated message below unreachable.
    serializer = WorkspaceInviteSerializer()
    result_i18n = serializer.serialize(result, language)

    return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -302,6 +367,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -312,7 +378,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -320,6 +386,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -354,14 +423,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
@@ -162,6 +162,44 @@ class Settings:
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
@@ -2,15 +2,37 @@ from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
# Return content and maintain all state information
|
||||
return {"data": content}
|
||||
|
||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
"""
|
||||
Problem processing node service class
|
||||
|
||||
Handles problem decomposition and extension operations using LLM services.
|
||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including the spit_data field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
|
||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -128,7 +169,7 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content
|
||||
text_parts = []
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
@@ -146,10 +187,19 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
@@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
@@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define a semaphore (local to this function, not at module level) to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -54,6 +75,23 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
@@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
|
||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
@@ -48,7 +60,7 @@ async def make_read_graph():
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
@@ -63,7 +75,7 @@ async def make_read_graph():
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
@@ -72,108 +84,3 @@ async def make_read_graph():
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
    """Run the read workflow once against a hard-coded sample query (manual smoke test).

    Loads memory config 17 for ``MemoryAgentService``, streams the graph produced
    by ``make_read_graph`` in ``updates`` mode, keeps the most recent
    ``summary_result`` published by a node whose name contains ``Summary``,
    prints it, and finally prints the elapsed wall-clock time.

    Exceptions raised while the graph runs are printed with their traceback
    instead of being propagated. The database session is closed on every path,
    including a failure while loading the memory configuration.
    """
    import time
    import traceback

    message = "昨天有什么好看的电影"
    end_user_id = '88a459f5_text09'  # group ID
    storage_type = 'neo4j'  # storage backend
    search_switch = '1'  # search switch
    user_rag_memory_id = 'wwwwwwww'  # user RAG memory ID

    # State keys under which Summary-type nodes publish their result; first match wins.
    summary_keys = ('InputSummary', 'RetrieveSummary', 'summary', 'SummaryFails')
    # (state key, field, payload_is_list) in the order the outputs are collected.
    intermediate_sources = (
        ('spit_data', '_intermediate', False),
        ('problem_extension', '_intermediate', False),
        ('retrieve', '_intermediate_outputs', True),
        ('verify', '_intermediate', False),
        ('summary', '_intermediate', False),
    )

    db_session = next(get_db())
    try:
        # Loaded inside the try so the session is released even if this call fails.
        memory_config = MemoryConfigService(db_session).load_memory_config(
            config_id=17,
            service_name="MemoryAgentService",
        )
        start = time.time()
        try:
            async with make_read_graph() as graph:
                config = {"configurable": {"thread_id": end_user_id}}
                # Initial state carrying every field the workflow nodes read.
                initial_state = {
                    "messages": [HumanMessage(content=message)],
                    "search_switch": search_switch,
                    "end_user_id": end_user_id,
                    "storage_type": storage_type,
                    "user_rag_memory_id": user_rag_memory_id,
                    "memory_config": memory_config,
                }
                # NOTE: collected for a result-merging step that is currently
                # disabled; nothing consumes this list yet.
                intermediate_outputs = []
                summary = ''

                async for update_event in graph.astream(
                    initial_state,
                    stream_mode="updates",
                    config=config,
                ):
                    for node_name, node_data in update_event.items():
                        print(f"处理节点: {node_name}")

                        # Summary-type nodes use different wrapper keys for the result.
                        if 'Summary' in node_name:
                            for key in summary_keys:
                                if key in node_data and 'summary_result' in node_data[key]:
                                    summary = node_data[key]['summary_result']
                                    break

                        for key, field, payload_is_list in intermediate_sources:
                            payload = node_data.get(key, {}).get(field)
                            if not payload:  # skips None, [] and {}
                                continue
                            if payload_is_list:
                                intermediate_outputs.extend(payload)
                            else:
                                intermediate_outputs.append(payload)

                print("=== 最终摘要 ===")
                print(summary)
        except Exception:
            # Best-effort debug entry point: report the failure and still print timing.
            traceback.print_exc()
    finally:
        db_session.close()

    end = time.time()
    print(100 * 'y')
    print(f"总耗时: {end - start}s")
    print(100 * 'y')


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,77 +2,104 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
    """
    Hand one user/assistant exchange to ``write_rag``.

    The two messages are merged into a single two-line string (a ``user:`` line
    followed by an ``assistant:`` line), passed to ``write_rag`` together with
    the user and RAG-memory identifiers, and the call is then recorded in the
    agent log.

    Args:
        end_user_id: User identifier forwarded to ``write_rag``
        user_message: Text of the user's message
        ai_message: Text of the assistant's reply
        user_rag_memory_id: RAG memory identifier forwarded to ``write_rag``
    """
    # RAG mode keeps the legacy plain-string format rather than structured messages.
    transcript = f"user: {user_message}\nassistant: {ai_message}"
    await write_rag(end_user_id, transcript, user_rag_memory_id)
    logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
actual_end_user_id, # end_user_id: User ID
|
||||
structured_messages, # message: JSON string format message list
|
||||
str(actual_config_id), # config_id: Configuration ID string
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
    """
    Flush a user's recent Redis sessions to Neo4j as one batch of messages.

    Fetches the sessions returned by ``write_store.find_user_recent_sessions``
    for ``end_user_id`` and ``time``, decodes each session's ``Query`` field
    (a JSON-encoded message list), concatenates the lists in session order and
    submits the result through ``write`` with the Neo4j storage type. Returns
    without writing when no sessions are found.

    Args:
        end_user_id: Terminal user identifier; also passed as the actual user ID
        memory_config: Memory configuration object exposing ``config_id``; a bare
            config ID is accepted as well, mirroring ``window_dialogue``
        time: Time window forwarded unchanged to ``find_user_recent_sessions``
            (its unit is defined by that helper and is not visible here)
    """
    sessions = write_store.find_user_recent_sessions(end_user_id, time)
    if not sessions:
        return

    # Accept either a config object or a raw config ID, as window_dialogue does.
    config_id = getattr(memory_config, 'config_id', memory_config)

    messages = []
    for session in sessions:
        messages.extend(json.loads(session['Query']))

    await write(
        AgentMemory_Long_Term.STORAGE_NEO4J,
        end_user_id,
        "",
        "",
        None,
        end_user_id,
        config_id,
        messages,
    )
|
||||
'''聚合判断'''
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
history = None
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
    """
    Argument schema for the time retrieval tool.

    Attributes:
        context: Query text entered by the user (required)
        end_user_id: Group ID used to filter search results; falls back to the
            sample ID "88a459f5_text09" when the caller omits it
    """
    context: str = Field(..., description="用户输入的查询内容")
    end_user_id: str = Field("88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,51 +216,60 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
    """
    Recursively strip metadata keys from a search result.

    Dictionaries are copied without the filtered keys, lists are rebuilt
    element by element, and both are descended into at any depth. Every other
    value (strings, numbers, None, tuples, ...) is returned unchanged.

    Args:
        data: Data to be cleaned (dict, list, or any other type)

    Returns:
        A cleaned copy for dicts and lists; the original object otherwise
    """
    # Keys to filter out. Built once per call instead of once per recursion level.
    # TODO: fact_summary functionality temporarily disabled, will be enabled after future development
    fields_to_remove = frozenset({
        'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
        'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
        'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary",
    })

    def _strip(node):
        """Return ``node`` with filtered keys removed from every nested dict."""
        if isinstance(node, dict):
            return {
                key: _strip(value)
                for key, value in node.items()
                if key not in fields_to_remove
            }
        if isinstance(node, list):
            return [_strip(item) for item in node]
        return node

    return _strip(data)
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -216,57 +279,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -275,38 +344,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -315,7 +398,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
Format and parse message lists into different output types
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
messages: List of message objects from storage containing message data
|
||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
list: Formatted message list in the specified format
|
||||
- 'string': List of formatted text messages with role prefixes
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -42,10 +41,26 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -165,7 +165,9 @@ async def write(
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
connector=neo4j_connector,
|
||||
config_id=config_id,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||
|
||||
__all__ = ["LabelPropagationEngine"]
|
||||
@@ -0,0 +1,484 @@
|
||||
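"""Label-propagation clustering engine.

Clusters ExtractedEntity nodes in Neo4j into communities using the dynamic
label propagation algorithm described in the ZEP paper.

Two modes are supported:
- Full initialization (full_clustering): first run; performs complete LPA
  iterations over all entities
- Incremental update (incremental_update): when new entities arrive, only the
  new entities and their neighbors are processed
"""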
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from math import sqrt
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum number of full-clustering iterations; guards against non-convergence
|
||||
MAX_ITERATIONS = 10
|
||||
# Number of core entities listed in a community's metadata
|
||||
CORE_ENTITY_LIMIT = 5
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||
if not v1 or not v2 or len(v1) != len(v2):
|
||||
return 0.0
|
||||
dot = sum(a * b for a, b in zip(v1, v2))
|
||||
norm1 = sqrt(sum(a * a for a in v1))
|
||||
norm2 = sqrt(sum(b * b for b in v2))
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
return dot / (norm1 * norm2)
|
||||
|
||||
|
||||
def _weighted_vote(
|
||||
neighbors: List[Dict],
|
||||
self_embedding: Optional[List[float]],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
加权多数投票,选出得票最高的社区。
|
||||
|
||||
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||
没有 community_id 的邻居不参与投票。
|
||||
"""
|
||||
votes: Dict[str, float] = {}
|
||||
for nb in neighbors:
|
||||
cid = nb.get("community_id")
|
||||
if not cid:
|
||||
continue
|
||||
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||
act = nb.get("activation_value") or 0.5
|
||||
# 语义相似度权重 0.6,激活值权重 0.4
|
||||
weight = 0.6 * sem + 0.4 * act
|
||||
votes[cid] = votes.get(cid, 0.0) + weight
|
||||
|
||||
if not votes:
|
||||
return None
|
||||
return max(votes, key=votes.__getitem__)
|
||||
|
||||
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.config_id = config_id
|
||||
self.llm_model_id = llm_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_entity_ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
统一入口:自动判断全量还是增量。
|
||||
|
||||
- 若该用户尚无 Community 节点 → 全量初始化
|
||||
- 否则 → 增量更新(仅处理 new_entity_ids)
|
||||
"""
|
||||
has_communities = await self.repo.has_communities(end_user_id)
|
||||
if not has_communities:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
|
||||
await self.full_clustering(end_user_id)
|
||||
else:
|
||||
if new_entity_ids:
|
||||
logger.info(
|
||||
f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
|
||||
)
|
||||
await self.incremental_update(new_entity_ids, end_user_id)
|
||||
|
||||
async def full_clustering(self, end_user_id: str) -> None:
|
||||
"""
|
||||
全量标签传播初始化。
|
||||
|
||||
1. 拉取所有实体,初始化每个实体为独立社区
|
||||
2. 迭代:每轮对所有实体做邻居投票,更新社区标签
|
||||
3. 直到标签不再变化或达到 MAX_ITERATIONS
|
||||
4. 将最终标签写入 Neo4j
|
||||
"""
|
||||
entities = await self.repo.get_all_entities(end_user_id)
|
||||
if not entities:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||
return
|
||||
|
||||
# 初始化:每个实体持有自己 id 作为社区标签
|
||||
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
|
||||
embeddings: Dict[str, Optional[List[float]]] = {
|
||||
e["id"]: e.get("name_embedding") for e in entities
|
||||
}
|
||||
|
||||
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
|
||||
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
|
||||
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
|
||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
changed = 0
|
||||
# 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
|
||||
for entity in entities:
|
||||
eid = entity["id"]
|
||||
# 直接从缓存取邻居,不再发起 Neo4j 查询
|
||||
neighbors = neighbors_cache.get(eid, [])
|
||||
|
||||
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
|
||||
enriched = []
|
||||
for nb in neighbors:
|
||||
nb_copy = dict(nb)
|
||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||
enriched.append(nb_copy)
|
||||
|
||||
new_label = _weighted_vote(enriched, embeddings.get(eid))
|
||||
if new_label and new_label != labels[eid]:
|
||||
labels[eid] = new_label
|
||||
changed += 1
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS},"
|
||||
f"标签变化数: {changed}"
|
||||
)
|
||||
if changed == 0:
|
||||
logger.info("[Clustering] 标签已收敛,提前结束迭代")
|
||||
break
|
||||
|
||||
# 将最终标签写入 Neo4j
|
||||
await self._flush_labels(labels, end_user_id)
|
||||
pre_merge_count = len(set(labels.values()))
|
||||
logger.info(
|
||||
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体,开始后处理合并"
|
||||
)
|
||||
|
||||
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
|
||||
all_community_ids = list(set(labels.values()))
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
# 为所有社区生成元数据
|
||||
# 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区
|
||||
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
|
||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||
surviving_community_ids = list({
|
||||
e.get("community_id") for e in surviving_communities
|
||||
if e.get("community_id")
|
||||
})
|
||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||
for cid in surviving_community_ids:
|
||||
await self._generate_community_metadata(cid, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
增量更新:只处理新实体及其邻居,不重跑全图。
|
||||
|
||||
1. 对每个新实体查询邻居
|
||||
2. 加权多数投票决定社区归属
|
||||
3. 若邻居无社区 → 创建新社区
|
||||
4. 若邻居分属多个社区 → 评估是否合并
|
||||
"""
|
||||
for entity_id in new_entity_ids:
|
||||
await self._process_single_entity(entity_id, end_user_id)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 内部方法
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _process_single_entity(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> None:
|
||||
"""处理单个新实体的社区分配。"""
|
||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||
|
||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||
self_embedding = await self._get_entity_embedding(entity_id, end_user_id)
|
||||
|
||||
if not neighbors:
|
||||
# 孤立实体:创建单成员社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
return
|
||||
|
||||
# 统计邻居社区分布
|
||||
community_ids_in_neighbors = set(
|
||||
nb["community_id"] for nb in neighbors if nb.get("community_id")
|
||||
)
|
||||
|
||||
target_cid = _weighted_vote(neighbors, self_embedding)
|
||||
|
||||
if target_cid is None:
|
||||
# 邻居都没有社区,连同新实体一起创建新社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
for nb in neighbors:
|
||||
await self.repo.assign_entity_to_community(
|
||||
nb["id"], new_cid, end_user_id
|
||||
)
|
||||
await self.repo.refresh_member_count(new_cid, end_user_id)
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata(new_cid, end_user_id)
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
await self.repo.refresh_member_count(target_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
|
||||
|
||||
# 若邻居分属多个社区,评估合并
|
||||
if len(community_ids_in_neighbors) > 1:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata(target_cid, end_user_id)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
评估多个社区是否应合并。
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.85
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
cids = list(community_ids)
|
||||
for i in range(len(cids)):
|
||||
for j in range(i + 1, len(cids)):
|
||||
sim = _cosine_similarity(
|
||||
community_embeddings[cids[i]],
|
||||
community_embeddings[cids[j]],
|
||||
)
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
|
||||
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
|
||||
# (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并)
|
||||
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
|
||||
|
||||
def get_root(x: str) -> str:
|
||||
"""路径压缩,找到 x 当前所属的根社区。"""
|
||||
while x in merged_into:
|
||||
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
|
||||
x = merged_into[x]
|
||||
return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
root1, root2 = get_root(c1), get_root(c2)
|
||||
if root1 == root2:
|
||||
continue
|
||||
|
||||
# 用合并后的最新平均向量重新验证相似度
|
||||
# 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并
|
||||
current_sim = _cosine_similarity(
|
||||
community_embeddings.get(root1),
|
||||
community_embeddings.get(root2),
|
||||
)
|
||||
if current_sim <= MERGE_THRESHOLD:
|
||||
# 合并后向量已漂移,不再满足阈值,跳过
|
||||
logger.debug(
|
||||
f"[Clustering] 跳过合并 {root1} ↔ {root2},"
|
||||
f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
merged_into[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
|
||||
# 合并后重新计算 keep 的平均向量(加权平均)
|
||||
keep_emb = community_embeddings.get(keep)
|
||||
dissolve_emb = community_embeddings.get(dissolve)
|
||||
keep_size = community_sizes.get(keep, 0)
|
||||
dissolve_size = community_sizes.get(dissolve, 0)
|
||||
total_size = keep_size + dissolve_size
|
||||
if keep_emb and dissolve_emb and total_size > 0:
|
||||
dim = len(keep_emb)
|
||||
community_embeddings[keep] = [
|
||||
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
|
||||
for i in range(dim)
|
||||
]
|
||||
community_embeddings[dissolve] = None
|
||||
|
||||
community_sizes[keep] = total_size
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
|
||||
)
|
||||
|
||||
async def _flush_labels(
|
||||
self, labels: Dict[str, str], end_user_id: str
|
||||
) -> None:
|
||||
"""将内存中的标签批量写入 Neo4j。"""
|
||||
# 先创建所有唯一社区节点
|
||||
unique_communities = set(labels.values())
|
||||
for cid in unique_communities:
|
||||
await self.repo.upsert_community(cid, end_user_id)
|
||||
|
||||
# 再批量分配实体
|
||||
for entity_id, community_id in labels.items():
|
||||
await self.repo.assign_entity_to_community(
|
||||
entity_id, community_id, end_user_id
|
||||
)
|
||||
|
||||
# 刷新成员数
|
||||
for cid in unique_communities:
|
||||
await self.repo.refresh_member_count(cid, end_user_id)
|
||||
|
||||
async def _get_entity_embedding(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[List[float]]:
|
||||
"""查询单个实体的 name_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
|
||||
"RETURN e.name_embedding AS name_embedding",
|
||||
eid=entity_id,
|
||||
uid=end_user_id,
|
||||
)
|
||||
return result[0]["name_embedding"] if result else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _generate_community_metadata(
    self, community_id: str, end_user_id: str
) -> None:
    """Generate and persist a community's metadata: name, summary, core entities.

    - core_entities: names of the first ``CORE_ENTITY_LIMIT`` members after
      sorting by ``activation_value`` in descending order (a missing or falsy
      value counts as 0). No LLM is needed for this part.
    - name / summary: produced by the LLM when ``self.llm_model_id`` is set
      and at least one member has a name; otherwise, or when the LLM call
      fails, a fallback built from the entity names is used.

    Every exception is caught and logged; nothing is propagated to the caller.

    Args:
        community_id: Identifier of the community to describe. Its first
            eight characters serve as the fallback name when no member has
            a name.
        end_user_id: Identifier passed through to the repository calls.
    """
    try:
        members = await self.repo.get_community_members(community_id, end_user_id)
        if not members:
            return

        # Core entities: top-N by activation_value, descending.
        ranked_members = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [
            m["name"] for m in ranked_members[:CORE_ENTITY_LIMIT] if m.get("name")
        ]
        all_names = [m["name"] for m in members if m.get("name")]

        # Fallback values, used as-is unless the LLM supplies better ones.
        name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
        summary = f"包含实体:{', '.join(all_names)}"

        # Only ask the LLM when there is something to describe; an empty
        # entity list would just invite an invented name.
        if self.llm_model_id and all_names:
            try:
                # Imported lazily, as in the original block.
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                entity_list_str = "、".join(all_names)
                prompt = (
                    f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                    f"请为这组实体所代表的主题:\n"
                    f"1. 起一个简洁的中文名称(不超过10个字)\n"
                    f"2. 写一句话摘要(不超过50个字)\n\n"
                    f"严格按以下格式输出,不要有其他内容:\n"
                    f"名称:<名称>\n摘要:<摘要>"
                )
                with get_db_context() as db:
                    factory = MemoryClientFactory(db)
                    llm_client = factory.get_llm_client(self.llm_model_id)
                # NOTE(review): the LLM call is made after the DB context is
                # closed so no session is held across the network round-trip;
                # this assumes the client does not need the session afterwards
                # -- confirm against MemoryClientFactory.
                response = await llm_client.chat([{"role": "user", "content": prompt}])
                text = response.content if hasattr(response, "content") else str(response)

                for raw_line in text.strip().splitlines():
                    line = raw_line.strip()
                    label, colon, value = line[:2], line[2:3], line[3:].strip()
                    # Accept both the full-width and the ASCII colon, and skip
                    # blank values so the fallback computed above survives.
                    if colon not in (":", ":") or not value:
                        continue
                    if label == "名称":
                        name = value
                    elif label == "摘要":
                        summary = value
            except Exception as e:
                logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

        await self.repo.update_community_metadata(
            community_id=community_id,
            end_user_id=end_user_id,
            name=name,
            summary=summary,
            core_entities=core_entities,
        )
        logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
    except Exception as e:
        logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
|
||||
|
||||
@staticmethod
def _new_community_id() -> str:
    """Return a freshly generated random (version 4) UUID as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
|
||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
自我反思引擎实现
|
||||
Self-Reflection Engine Implementation
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
1. 基于时间的反思 - 根据时间周期触发反思
|
||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
||||
3. 综合反思 - 整合多种反思策略
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
This module implements the self-reflection functionality of the memory system, including:
|
||||
1. Time-based reflection - Triggers reflection based on time cycles
|
||||
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||
4. Reflection result application - Updates memory database
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
# Configure logging
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(
|
||||
@@ -49,35 +49,62 @@ else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
class TranslationResponse(BaseModel):
    """Translation response model for language conversion"""
    data: str  # Payload of the translation response
|
||||
|
||||
class ReflectionRange(str, Enum):
    """
    Reflection range enumeration

    Defines the scope of data to be included in reflection operations.
    """
    PARTIAL = "partial"  # Reflect from retrieval results
    ALL = "all"  # Reflect from entire database
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
    """
    Reflection baseline enumeration

    Defines the strategy or approach used for reflection operations.
    """
    TIME = "TIME"  # Time-based reflection
    FACT = "FACT"  # Fact-based reflection
    HYBRID = "HYBRID"  # Hybrid reflection combining multiple strategies
|
||||
|
||||
|
||||
class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
"""
|
||||
Reflection engine configuration
|
||||
|
||||
Defines all configuration parameters for the reflection engine including
|
||||
operation modes, model settings, and evaluation criteria.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether reflection engine is enabled
|
||||
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||
model_id: LLM model identifier for reflection operations
|
||||
end_user_id: User identifier for scoped operations
|
||||
output_example: Example output format for guidance
|
||||
memory_verify: Enable memory verification checks
|
||||
quality_assessment: Enable quality assessment evaluation
|
||||
violation_handling_strategy: Strategy for handling violations
|
||||
language_type: Language type for output ("zh" or "en")
|
||||
"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
iteration_period: str = "3" # Reflection cycle period
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
model_id: Optional[str] = None # Model ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
output_example: Optional[str] = None # Output example
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
# Evaluation related fields
|
||||
memory_verify: bool = True # Memory verification
|
||||
quality_assessment: bool = True # Quality assessment
|
||||
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
||||
|
||||
|
||||
class ReflectionResult(BaseModel):
|
||||
"""反思结果"""
|
||||
"""
|
||||
Reflection operation result
|
||||
|
||||
Contains comprehensive information about the outcome of a reflection operation
|
||||
including success status, metrics, and execution details.
|
||||
|
||||
Attributes:
|
||||
success: Whether the reflection operation succeeded
|
||||
message: Descriptive message about the operation result
|
||||
conflicts_found: Number of conflicts detected during reflection
|
||||
conflicts_resolved: Number of conflicts successfully resolved
|
||||
memories_updated: Number of memory entries updated in database
|
||||
execution_time: Total time taken for the reflection operation
|
||||
details: Additional details about the operation (optional)
|
||||
"""
|
||||
success: bool
|
||||
message: str
|
||||
conflicts_found: int = 0
|
||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
||||
|
||||
class ReflectionEngine:
|
||||
"""
|
||||
自我反思引擎
|
||||
|
||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
||||
Self-Reflection Engine
|
||||
|
||||
Responsible for executing memory system self-reflection operations including
|
||||
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||
reflection strategies and provides comprehensive result tracking.
|
||||
|
||||
The engine can operate in different modes:
|
||||
- Time-based: Reflects on memories within specific time periods
|
||||
- Fact-based: Detects and resolves factual conflicts in memories
|
||||
- Hybrid: Combines multiple reflection strategies
|
||||
|
||||
Attributes:
|
||||
config: Reflection engine configuration
|
||||
neo4j_connector: Neo4j database connector
|
||||
llm_client: Language model client for analysis
|
||||
Various function handlers for data processing and prompt rendering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
Initialize reflection engine
|
||||
|
||||
Sets up the reflection engine with configuration and optional dependencies.
|
||||
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||
|
||||
Args:
|
||||
config: 反思引擎配置
|
||||
neo4j_connector: Neo4j 连接器(可选)
|
||||
llm_client: LLM 客户端(可选)
|
||||
get_data_func: 获取数据的函数(可选)
|
||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
||||
conflict_schema: 冲突结果 Schema(可选)
|
||||
reflexion_schema: 反思结果 Schema(可选)
|
||||
update_query: 更新查询语句(可选)
|
||||
config: Reflection engine configuration object
|
||||
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||
llm_client: LLM client instance (optional, will be created if not provided)
|
||||
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||
conflict_schema: Schema for conflict result validation (optional)
|
||||
reflexion_schema: Schema for reflection result validation (optional)
|
||||
update_query: Query string for database updates (optional)
|
||||
"""
|
||||
self.config = config
|
||||
self.neo4j_connector = neo4j_connector
|
||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
# Lazy import to avoid circular dependencies
|
||||
self._lazy_init_done = False
|
||||
|
||||
def _lazy_init(self):
|
||||
"""延迟初始化,避免循环导入"""
|
||||
"""
|
||||
Lazy initialization to avoid circular imports
|
||||
|
||||
Initializes dependencies only when needed, preventing circular import issues
|
||||
and optimizing startup performance. Sets up default implementations for
|
||||
any components not provided during construction.
|
||||
"""
|
||||
if self._lazy_init_done:
|
||||
return
|
||||
|
||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# If llm_client is a string (model_id), use it to initialize the client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||
"max_tokens": 600, # Limit maximum token count
|
||||
"top_p": 0.8, # Optimize sampling parameters
|
||||
"stream": False, # Ensure non-streaming output for fastest response
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
||||
if self.get_data_func is None:
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
# Import get_data_statement function
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
||||
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
Execute complete reflection workflow
|
||||
|
||||
Performs the full reflection process including data retrieval, conflict detection,
|
||||
conflict resolution, and memory updates. This is the main entry point for
|
||||
reflection operations.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host identifier for scoping reflection operations
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive result of the reflection operation including
|
||||
success status, conflict metrics, and execution time
|
||||
"""
|
||||
# 延迟初始化
|
||||
# Lazy initialization
|
||||
self._lazy_init()
|
||||
|
||||
if not self.config.enabled:
|
||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
# 1. Get reflection data
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
# 2. Detect conflicts (fact-based reflection)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
# 4. Apply reflection results (update memory database)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
"""
|
||||
Translate Chinese text to English
|
||||
|
||||
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||
Provides consistent translation format for reflection results.
|
||||
|
||||
Args:
|
||||
text: Chinese text to be translated
|
||||
|
||||
Returns:
|
||||
str: Translated English text
|
||||
"""
|
||||
# Translate Chinese to English
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
"""
|
||||
Extract and translate reflection data to English
|
||||
|
||||
Processes reflection data structure and translates all Chinese content to English.
|
||||
Handles nested data structures including memory verifications, quality assessments,
|
||||
and reflection data while preserving the original structure.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing reflection data with Chinese content
|
||||
|
||||
Returns:
|
||||
dict: Translated data structure with English content
|
||||
"""
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
"""
|
||||
Execute reflection workflow with comprehensive data processing
|
||||
|
||||
Performs a complete reflection operation including conflict detection, resolution,
|
||||
and result formatting. Supports both Chinese and English output based on
|
||||
configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Comprehensive reflection results including source data, memory verifications,
|
||||
quality assessments, and reflection data. Results are translated to English
|
||||
if language_type is set to 'en'.
|
||||
"""
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
# Traverse data to extract fields
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
# Remove created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
# Traverse data to extract reflexion fields
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
"""
|
||||
Extract source_data and databasets fields from example.json
|
||||
|
||||
Reads reflection example data from the example.json file and extracts
|
||||
the source data and database statements for testing and demonstration purposes.
|
||||
|
||||
Returns:
|
||||
tuple: (source_data, databasets) extracted from the example file
|
||||
Returns empty lists if file reading fails
|
||||
"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
# Read JSON file
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
# Extract fields under memory_verify
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
|
||||
根据配置的反思范围获取需要反思的记忆数据。
|
||||
Get reflection data from database
|
||||
|
||||
Retrieves memory data for reflection based on the configured reflection range.
|
||||
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping data retrieval
|
||||
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||
Returns empty lists if query fails
|
||||
"""
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
||||
Detect conflicts (fact-based reflection)
|
||||
|
||||
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||
Performs comprehensive conflict detection including memory verification and
|
||||
quality assessment based on configuration settings.
|
||||
|
||||
Args:
|
||||
data: 待检测的记忆数据
|
||||
data: Memory data to be analyzed for conflicts
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 冲突记忆列表
|
||||
List[Any]: List of detected conflicts with detailed analysis
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
# Data preprocessing: if data is too small, return no conflicts directly
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# Use converted data
|
||||
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
# Render conflict detection prompt
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema,
|
||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
# 调用 LLM 进行冲突检测
|
||||
# Call LLM for conflict detection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.conflict_schema
|
||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
||||
logging.error("LLM 冲突检测输出解析失败")
|
||||
return []
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return [response.model_dump()]
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
使用 LLM 对检测到的冲突进行反思和解决。
|
||||
Resolve detected conflicts
|
||||
|
||||
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||
|
||||
Args:
|
||||
conflicts: 冲突列表
|
||||
conflicts: List of conflicts to be resolved
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 解决方案列表
|
||||
List[Any]: List of resolution solutions with reflection results
|
||||
"""
|
||||
if not conflicts:
|
||||
return []
|
||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
# Process each conflict in parallel
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解决单个冲突"""
|
||||
"""Resolve a single conflict"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
# 渲染反思提示词
|
||||
# Render reflection prompt
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema,
|
||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
# 调用 LLM 进行反思
|
||||
# Call LLM for reflection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.reflexion_schema
|
||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return response.model_dump()
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
||||
logging.warning(f"解决单个冲突失败: {e}")
|
||||
return None
|
||||
|
||||
# 并发执行所有冲突解决任务
|
||||
# Execute all conflict resolution tasks concurrently
|
||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 过滤掉失败的结果
|
||||
# Filter out failed results
|
||||
solved = [r for r in results if r is not None]
|
||||
|
||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
|
||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
||||
Apply reflection results (update memory database)
|
||||
|
||||
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||
Processes the solved data and applies changes to the memory storage system.
|
||||
|
||||
Args:
|
||||
solved_data: 解决方案列表
|
||||
solved_data: List of resolved conflict solutions with reflection data
|
||||
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
int: Number of successfully updated memory entries
|
||||
"""
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
|
||||
# 基于时间的反思方法
|
||||
# Time-based reflection methods
|
||||
async def time_based_reflection(
    self,
    host_id: uuid.UUID,
    time_period: Optional[str] = None
) -> ReflectionResult:
    """
    Time-based reflection

    Entry point for reflection triggered on a time cycle. The effective
    period is ``time_period`` when given, otherwise the configured
    ``iteration_period``.

    NOTE(review): the period is currently only written to the log; the
    standard workflow in ``execute_reflection`` is then run with ``host_id``
    alone, so no time filtering is applied here -- confirm whether that is
    intended.

    Args:
        host_id: Host UUID identifier for scoping reflection
        time_period: Time period (e.g., "three hours"), uses config value if not provided

    Returns:
        ReflectionResult: Result returned by ``execute_reflection``
    """
    effective_period = time_period or self.config.iteration_period
    logging.info(f"执行基于时间的反思,周期: {effective_period}")

    # Use standard reflection workflow
    return await self.execute_reflection(host_id)
|
||||
|
||||
# 基于事实的反思方法
|
||||
# Fact-based reflection methods
|
||||
async def fact_based_reflection(
    self,
    host_id: uuid.UUID
) -> ReflectionResult:
    """
    Fact-based reflection

    Entry point for reflection aimed at factual conflicts within memories.
    It logs the start of the run and delegates to the standard workflow in
    ``execute_reflection``.

    Args:
        host_id: Host UUID identifier for scoping reflection

    Returns:
        ReflectionResult: Result returned by ``execute_reflection``
    """
    logging.info("执行基于事实的反思")

    # Use standard reflection workflow
    outcome = await self.execute_reflection(host_id)
    return outcome
|
||||
|
||||
# 综合反思方法
|
||||
# Comprehensive reflection methods
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
|
||||
整合基于时间和基于事实的反思策略。
|
||||
Comprehensive reflection
|
||||
|
||||
Integrates time-based and fact-based reflection strategies based on
|
||||
the configured baseline. Supports hybrid approaches that combine
|
||||
multiple reflection methodologies.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result combining
|
||||
multiple strategies if using hybrid baseline
|
||||
"""
|
||||
logging.info("执行综合反思")
|
||||
|
||||
# 根据配置的基线选择反思策略
|
||||
# Choose reflection strategy based on configured baseline
|
||||
if self.config.baseline == ReflectionBaseline.TIME:
|
||||
return await self.time_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||
return await self.fact_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
||||
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||
time_result = await self.time_based_reflection(host_id)
|
||||
fact_result = await self.fact_based_reflection(host_id)
|
||||
|
||||
# 合并结果
|
||||
# Merge results
|
||||
return ReflectionResult(
|
||||
success=time_result.success and fact_result.success,
|
||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||
|
||||
@@ -2,9 +2,17 @@ import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
Escape special characters in Lucene queries
|
||||
|
||||
Prevents ParseException when using Neo4j full-text search procedures.
|
||||
Escapes all Lucene reserved special characters and operators.
|
||||
|
||||
Args:
|
||||
query: Original query string
|
||||
|
||||
Returns:
|
||||
str: Escaped query string safe for Lucene search
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
"""
|
||||
Extract clean plain-text query from various input forms
|
||||
|
||||
Handles the following cases:
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Fallbacks to the raw string when parsing fails
|
||||
- Falls back to raw string when parsing fails
|
||||
- Handles dictionary-type input
|
||||
- Best-effort unescape common escape characters
|
||||
|
||||
Args:
|
||||
query_input: Query input in various forms (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
str: Extracted plain-text query string
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
    """
    Check whether a string has the textual shape of a YYYY-MM-DD date.

    Only the layout is checked: a 4-digit year, then a 1-2 digit month and
    a 1-2 digit day, separated by hyphens. The numeric ranges are not
    validated, so a value such as "2024-13-45" is accepted.

    Args:
        date_str: Date string to validate

    Returns:
        bool: True if the string matches the layout, False otherwise
    """
    shape_match = re.match(r"^\d{4}-\d{1,2}-\d{1,2}$", date_str)
    return shape_match is not None
|
||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
|
||||
"""预处理日期字符串,处理特殊格式"""
|
||||
"""
|
||||
预处理日期字符串,处理特殊格式
|
||||
|
||||
处理以下特殊格式:
|
||||
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||
- 混合分隔符,统一为 "-"
|
||||
|
||||
Args:
|
||||
date_str: 原始日期字符串
|
||||
|
||||
Returns:
|
||||
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||
"""
|
||||
|
||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
|
||||
"""备选解析方案"""
|
||||
"""
|
||||
备选日期解析方案
|
||||
|
||||
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||
支持多种常见的日期格式,包括:
|
||||
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||
- YYYYMMDD, YYMMDD
|
||||
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||
|
||||
Args:
|
||||
date_str: 待解析的日期字符串
|
||||
|
||||
Returns:
|
||||
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||
"""
|
||||
|
||||
# 尝试常见的日期格式[citation:4][citation:5]
|
||||
formats_to_try = [
|
||||
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -620,11 +621,12 @@ class BaseNode(ABC):
|
||||
|
||||
@staticmethod
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
@@ -643,7 +645,7 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
@@ -653,7 +655,8 @@ class BaseNode(ABC):
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
[file_obj]
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """
    Build a snapshot of the values this node's conditions are evaluated on.

    For every configured case, each expression is reported with its
    resolved left operand, its right operand and its operator. A right
    operand whose input type is CONSTANT is reported as configured; any
    other right operand is looked up in the variable pool.

    Args:
        state: Current workflow state (not read by this method)
        variable_pool: Pool used to resolve variable references

    Returns:
        dict: {"cases": [...]} with one entry per configured case, each
        holding its "expressions" list and its "logical_operator"
    """

    def resolve_right(expr):
        # Constants are passed through untouched; everything else is a
        # variable reference that must be resolved (non-strict lookup).
        if expr.input_type == ValueInputType.CONSTANT:
            return expr.right
        return self.get_variable(expr.right, variable_pool, strict=False)

    cases = [
        {
            "expressions": [
                {
                    "left": self.get_variable(expr.left, variable_pool, strict=False),
                    "right": resolve_right(expr),
                    "operator": expr.operator,
                }
                for expr in case.expressions
            ],
            "logical_operator": case.logical_operator,
        }
        for case in self.typed_config.cases
    ]
    return {"cases": cases}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
|
||||
@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """
    Build a snapshot of the inputs used by this retrieval node.

    Args:
        state: Current workflow state (not read by this method)
        variable_pool: Pool used to render the query template

    Returns:
        dict: The rendered "query" string and the configured
        "knowledge_bases", each dumped in JSON mode
    """
    rendered_query = self._render_template(self.typed_config.query, variable_pool)
    dumped_kbs = [
        kb.model_dump(mode="json")
        for kb in self.typed_config.knowledge_bases
    ]
    return {"query": rendered_query, "knowledge_bases": dumped_kbs}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
model_info = ModelInfo(
|
||||
model_name=api_config.model_name,
|
||||
model_type=ModelType(config.type),
|
||||
api_key=api_config.api_key,
|
||||
api_base=api_config.api_base,
|
||||
provider=api_config.provider,
|
||||
is_omni=api_config.is_omni,
|
||||
capability=api_config.capability
|
||||
)
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
model_name=model_info.model_name,
|
||||
provider=model_info.provider,
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
type=model_info.model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
logger.debug(
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
|
||||
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
    """
    Build a snapshot of the inputs used by this parameter-extraction node.

    Args:
        state: Current workflow state (not read by this method)
        variable_pool: Pool used to render the text and prompt templates

    Returns:
        dict: The rendered "text" and "prompt", the configured "params"
        dumped in JSON mode, and the "model_id" converted to a string
    """
    cfg = self.typed_config
    rendered_text = self._render_template(cfg.text, variable_pool)
    rendered_prompt = self._render_template(cfg.prompt, variable_pool)
    dumped_params = [p.model_dump(mode="json") for p in cfg.params]
    return {
        "text": rendered_text,
        "prompt": rendered_prompt,
        "params": dumped_params,
        "model_id": str(cfg.model_id),
    }
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
|
||||
61
api/app/i18n/README.md
Normal file
61
api/app/i18n/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Internationalization (i18n) Module
|
||||
|
||||
This module provides internationalization support for the MemoryBear API.
|
||||
|
||||
## Components
|
||||
|
||||
- `service.py` - Translation service and core translation logic
|
||||
- `middleware.py` - Language detection middleware
|
||||
- `dependencies.py` - FastAPI dependency injection functions
|
||||
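- `exceptions.py` - Internationalized exception classes
- `loader.py` - Translation loader (`TranslationLoader`)
- `cache.py` - Translation cache with LRU eviction and statistics (`TranslationCache`)
- `logger.py` - Translation logging helpers (`TranslationLogger`, `log_missing_translation`, `log_translation_error`)
- `serializers.py` - Response serializers (`I18nResponseMixin`, workspace serializers)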
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
message = t("common.validation.required", field="Name")
|
||||
```
|
||||
|
||||
### Enum Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t_enum
|
||||
|
||||
# Translate enum value
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
```
|
||||
|
||||
### In FastAPI Endpoints
|
||||
|
||||
```python
|
||||
from typing import Callable

from fastapi import Depends
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
workspace = await workspace_service.create(data)
|
||||
return {
|
||||
"success": True,
|
||||
"message": t("workspace.created_successfully"),
|
||||
"data": workspace
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
See `app/core/config.py` for i18n configuration options:
|
||||
|
||||
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
|
||||
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
|
||||
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
|
||||
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)
|
||||
124
api/app/i18n/__init__.py
Normal file
124
api/app/i18n/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Internationalization (i18n) module for MemoryBear Enterprise.
|
||||
|
||||
This module provides complete i18n support for the backend API including:
|
||||
- Translation loading from multiple directories (community + enterprise)
|
||||
- Translation service with caching and fallback
|
||||
- Language detection middleware
|
||||
- Dependency injection for FastAPI
|
||||
- Convenience functions for easy usage
|
||||
|
||||
Usage:
|
||||
from app.i18n import t, t_enum
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
error = t("common.validation.required", field="名称")
|
||||
|
||||
# Enum translation
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
"""
|
||||
|
||||
from app.i18n.dependencies import (
|
||||
get_current_language,
|
||||
get_enum_translator,
|
||||
get_translator,
|
||||
)
|
||||
from app.i18n.exceptions import (
|
||||
BadRequestError,
|
||||
ConflictError,
|
||||
FileNotFoundError,
|
||||
FileTooLargeError,
|
||||
ForbiddenError,
|
||||
I18nException,
|
||||
InternalServerError,
|
||||
InvalidCredentialsError,
|
||||
InvalidFileTypeError,
|
||||
NotFoundError,
|
||||
QuotaExceededError,
|
||||
RateLimitExceededError,
|
||||
ServiceUnavailableError,
|
||||
TenantNotFoundError,
|
||||
TenantSuspendedError,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
UnauthorizedError,
|
||||
UserAlreadyExistsError,
|
||||
UserNotFoundError,
|
||||
ValidationError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspacePermissionDeniedError,
|
||||
get_current_locale,
|
||||
set_current_locale,
|
||||
)
|
||||
from app.i18n.loader import TranslationLoader
|
||||
from app.i18n.logger import (
|
||||
TranslationLogger,
|
||||
get_translation_logger,
|
||||
log_missing_translation,
|
||||
log_translation_error,
|
||||
)
|
||||
from app.i18n.middleware import LanguageMiddleware
|
||||
from app.i18n.serializers import (
|
||||
I18nResponseMixin,
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer,
|
||||
)
|
||||
from app.i18n.service import (
|
||||
TranslationService,
|
||||
get_translation_service,
|
||||
t,
|
||||
t_enum,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TranslationLoader",
|
||||
"LanguageMiddleware",
|
||||
"TranslationService",
|
||||
"get_translation_service",
|
||||
"t",
|
||||
"t_enum",
|
||||
"get_current_language",
|
||||
"get_translator",
|
||||
"get_enum_translator",
|
||||
# Context management
|
||||
"get_current_locale",
|
||||
"set_current_locale",
|
||||
# Logging
|
||||
"TranslationLogger",
|
||||
"get_translation_logger",
|
||||
"log_missing_translation",
|
||||
"log_translation_error",
|
||||
# Serializers
|
||||
"I18nResponseMixin",
|
||||
"WorkspaceSerializer",
|
||||
"WorkspaceMemberSerializer",
|
||||
"WorkspaceInviteSerializer",
|
||||
# Exception classes
|
||||
"I18nException",
|
||||
"BadRequestError",
|
||||
"UnauthorizedError",
|
||||
"ForbiddenError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"ValidationError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspacePermissionDeniedError",
|
||||
"UserNotFoundError",
|
||||
"UserAlreadyExistsError",
|
||||
"TenantNotFoundError",
|
||||
"TenantSuspendedError",
|
||||
"InvalidCredentialsError",
|
||||
"TokenExpiredError",
|
||||
"TokenInvalidError",
|
||||
"FileNotFoundError",
|
||||
"FileTooLargeError",
|
||||
"InvalidFileTypeError",
|
||||
"RateLimitExceededError",
|
||||
"QuotaExceededError",
|
||||
]
|
||||
291
api/app/i18n/cache.py
Normal file
291
api/app/i18n/cache.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Advanced caching system for i18n translations.
|
||||
|
||||
This module provides:
|
||||
- LRU cache for hot translations
|
||||
- Lazy loading mechanism
|
||||
- Memory optimization
|
||||
- Cache statistics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationCache:
    """
    Two-tier in-memory cache for i18n translations.

    Tiers:
    - Main cache: the full translation data per locale, stored as
      {locale: {namespace: {key: value}}}.
    - LRU cache: a bounded map of already-resolved translation strings,
      keyed by "locale:namespace:dotted.key.path". The least recently
      used entry is evicted when the bound is reached.

    Hit/miss counters are kept for both tiers. The ``enable_lazy_load``
    flag and the ``lazy_loads`` counter are stored and reported, but no
    method of this class reads the flag or increments the counter.
    """

    def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
        """
        Initialize the translation cache.

        Args:
            max_lru_size: Maximum number of entries in the LRU tier. A value
                of zero or less disables the LRU tier; lookups are then
                served from the main cache only.
            enable_lazy_load: Lazy-loading flag, stored on the instance
        """
        self.max_lru_size = max_lru_size
        self.enable_lazy_load = enable_lazy_load

        # Main cache: {locale: {namespace: {key: value}}}
        self._main_cache: Dict[str, Dict[str, Any]] = {}

        # LRU cache for hot translations (oldest entry first)
        self._lru_cache: OrderedDict = OrderedDict()

        # Loaded locales tracker
        self._loaded_locales: set = set()

        # Statistics
        self._stats = self._empty_stats()

        logger.info(
            f"TranslationCache initialized with LRU size: {max_lru_size}, "
            f"lazy loading: {enable_lazy_load}"
        )

    @staticmethod
    def _empty_stats() -> Dict[str, int]:
        """Return a fresh statistics dictionary with every counter at zero."""
        return {
            "hits": 0,
            "misses": 0,
            "lru_hits": 0,
            "lru_misses": 0,
            "lazy_loads": 0
        }

    def set_locale_data(self, locale: str, data: Dict[str, Any]):
        """
        Set (or replace) the translation data for a locale.

        Any LRU entries previously resolved for this locale are dropped,
        so that a reload never keeps serving translations from the old
        data.

        Args:
            locale: Locale code
            data: Translation data dictionary, {namespace: {key: value}}
        """
        self._main_cache[locale] = data
        self._loaded_locales.add(locale)
        self._purge_lru_for_locale(locale)
        logger.debug(f"Loaded locale '{locale}' into cache")

    def get_translation(
        self,
        locale: str,
        namespace: str,
        key_path: list
    ) -> Optional[str]:
        """
        Look up a translation, consulting the LRU tier first.

        Args:
            locale: Locale code
            namespace: Translation namespace
            key_path: List of nested keys leading to the translation

        Returns:
            The translation string, or None when the locale, namespace or
            key path is unknown, or when the resolved value is not a string
        """
        # Build cache key for LRU
        cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}"

        # Check LRU cache first (hot translations)
        if cache_key in self._lru_cache:
            self._stats["lru_hits"] += 1
            self._stats["hits"] += 1
            # Move to end (most recently used)
            self._lru_cache.move_to_end(cache_key)
            return self._lru_cache[cache_key]

        self._stats["lru_misses"] += 1

        # Check main cache
        if locale not in self._main_cache:
            self._stats["misses"] += 1
            return None

        if namespace not in self._main_cache[locale]:
            self._stats["misses"] += 1
            return None

        # Navigate through nested keys
        current = self._main_cache[locale][namespace]
        for key in key_path:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                self._stats["misses"] += 1
                return None

        # Only leaf strings are translations; a nested dict or any other
        # value type is reported as a miss.
        if not isinstance(current, str):
            self._stats["misses"] += 1
            return None

        self._stats["hits"] += 1

        # Add to LRU cache
        self._add_to_lru(cache_key, current)

        return current

    def _add_to_lru(self, key: str, value: str):
        """
        Add a translation to the LRU tier, evicting old entries if needed.

        Args:
            key: Cache key
            value: Translation value
        """
        # A non-positive bound disables the LRU tier. Without this guard
        # the eviction below would pop from an empty OrderedDict and
        # raise KeyError.
        if self.max_lru_size <= 0:
            return

        # Refreshing an existing key must not evict an unrelated entry.
        if key in self._lru_cache:
            self._lru_cache[key] = value
            self._lru_cache.move_to_end(key)
            return

        # Remove oldest entries until there is room for the new one
        while len(self._lru_cache) >= self.max_lru_size:
            self._lru_cache.popitem(last=False)

        self._lru_cache[key] = value

    def _purge_lru_for_locale(self, locale: str):
        """Remove every LRU entry whose key belongs to the given locale."""
        prefix = f"{locale}:"
        # Collect first: the OrderedDict must not change size while iterated.
        stale_keys = [k for k in self._lru_cache if k.startswith(prefix)]
        for stale_key in stale_keys:
            del self._lru_cache[stale_key]

    def is_locale_loaded(self, locale: str) -> bool:
        """
        Check if a locale is loaded.

        Args:
            locale: Locale code

        Returns:
            True if locale is loaded
        """
        return locale in self._loaded_locales

    def get_loaded_locales(self) -> list:
        """
        Get list of loaded locales.

        Returns:
            List of locale codes (in no particular order)
        """
        return list(self._loaded_locales)

    def clear_lru(self):
        """Clear the LRU cache."""
        self._lru_cache.clear()
        logger.info("LRU cache cleared")

    def clear_locale(self, locale: str):
        """
        Clear cache for a specific locale.

        Removes the locale's data from the main cache, forgets that it was
        loaded, and drops its entries from the LRU tier.

        Args:
            locale: Locale code
        """
        if locale in self._main_cache:
            del self._main_cache[locale]
            self._loaded_locales.discard(locale)

        # Clear related LRU entries
        self._purge_lru_for_locale(locale)

        logger.info(f"Cleared cache for locale '{locale}'")

    def clear_all(self):
        """Clear all caches."""
        self._main_cache.clear()
        self._lru_cache.clear()
        self._loaded_locales.clear()
        logger.info("All caches cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with the raw counters, the overall and LRU hit rates
            (percentages rounded to two decimals, 0 when there were no
            requests), the current and maximum LRU size, and the number of
            loaded locales
        """
        total_requests = self._stats["hits"] + self._stats["misses"]
        hit_rate = (
            self._stats["hits"] / total_requests * 100
            if total_requests > 0
            else 0
        )

        lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
        lru_hit_rate = (
            self._stats["lru_hits"] / lru_total * 100
            if lru_total > 0
            else 0
        )

        return {
            "total_requests": total_requests,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(hit_rate, 2),
            "lru_hits": self._stats["lru_hits"],
            "lru_misses": self._stats["lru_misses"],
            "lru_hit_rate": round(lru_hit_rate, 2),
            "lru_size": len(self._lru_cache),
            "lru_max_size": self.max_lru_size,
            "loaded_locales": len(self._loaded_locales),
            "lazy_loads": self._stats["lazy_loads"]
        }

    def reset_stats(self):
        """Reset cache statistics."""
        self._stats = self._empty_stats()
        logger.info("Cache statistics reset")

    @staticmethod
    def _deep_sizeof(root: Any) -> int:
        """
        Sum sys.getsizeof over an object and everything it contains.

        Follows dict keys/values and list/tuple/set members. Each distinct
        object is counted once, so shared values are not double counted
        within one call.
        """
        import sys

        total = 0
        seen_ids = set()
        pending = [root]
        # Iterative walk so deeply nested translation data cannot hit the
        # recursion limit.
        while pending:
            obj = pending.pop()
            if id(obj) in seen_ids:
                continue
            seen_ids.add(id(obj))
            total += sys.getsizeof(obj)
            if isinstance(obj, dict):
                pending.extend(obj.keys())
                pending.extend(obj.values())
            elif isinstance(obj, (list, tuple, set, frozenset)):
                pending.extend(obj)
        return total

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Estimate memory usage of the cache.

        sys.getsizeof is shallow, so both tiers are measured by walking
        their contents. The two tiers are measured independently; a string
        held by both is counted in each figure and therefore twice in the
        total.

        Returns:
            Dictionary with the estimated size of the main cache, the LRU
            cache and their sum, in bytes and in megabytes (rounded to two
            decimals)
        """
        main_cache_size = self._deep_sizeof(self._main_cache)
        lru_cache_size = self._deep_sizeof(self._lru_cache)

        return {
            "main_cache_bytes": main_cache_size,
            "lru_cache_bytes": lru_cache_size,
            "total_bytes": main_cache_size + lru_cache_size,
            "main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
            "lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
            "total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
        }
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
    """
    Build the "locale:namespace:key" cache key for a translation.

    Results are memoized by ``lru_cache`` for up to 128 distinct argument
    combinations, so repeated requests for the same key reuse the string
    built on the first call.

    Args:
        locale: Locale code
        namespace: Translation namespace
        key: Translation key

    Returns:
        Cache key string
    """
    return "{}:{}:{}".format(locale, namespace, key)
|
||||
158
api/app/i18n/dependencies.py
Normal file
158
api/app/i18n/dependencies.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
FastAPI dependency injection functions for i18n.
|
||||
|
||||
This module provides dependency injection functions that can be used
|
||||
in FastAPI route handlers to access the current language and translator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_language(request: Request) -> str:
    """
    Resolve the language to use for the given request.

    Reads ``request.state.language`` (expected to be populated by the
    LanguageMiddleware). When that attribute is absent or ``None``, the
    configured default language is returned and a warning is logged.

    Args:
        request: FastAPI request object

    Returns:
        Language code (e.g., "zh", "en")

    Usage:
        @router.get("/example")
        async def example(language: str = Depends(get_current_language)):
            return {"language": language}
    """
    language = getattr(request.state, "language", None)
    if language is not None:
        return language

    # Imported lazily, only on the fallback path.
    from app.core.config import settings

    fallback = settings.I18N_DEFAULT_LANGUAGE
    logger.warning(
        "Language not found in request.state, using default: "
        f"{fallback}"
    )
    return fallback
|
||||
|
||||
|
||||
async def get_translator(request: Request) -> Callable:
    """
    Build a translation function pinned to this request's language.

    The language is resolved once via ``get_current_language`` and captured
    in a closure, so route handlers can translate keys without passing the
    language around.

    Args:
        request: FastAPI request object

    Returns:
        Translation function with signature: t(key: str, **params) -> str

    Usage:
        @router.get("/items")
        async def get_items(t: Callable = Depends(get_translator)):
            return {
                "title": t("items.title"),
                "message": t("items.found", count=5),
            }
    """
    language = await get_current_language(request)
    service = get_translation_service()

    def translate(key: str, **params) -> str:
        """
        Translate *key* into the language captured from the request.

        Args:
            key: Translation key (e.g., "common.success.created")
            **params: Parameters for parameterized messages

        Returns:
            Translated string
        """
        return service.translate(key, language, **params)

    return translate
|
||||
|
||||
|
||||
async def get_enum_translator(request: Request) -> Callable:
    """
    Build an enum-translation function pinned to this request's language.

    The language is resolved once via ``get_current_language`` and captured
    in a closure that delegates to the translation service's
    ``translate_enum``.

    Args:
        request: FastAPI request object

    Returns:
        Enum translation function with signature:
        t_enum(enum_type: str, value: str) -> str

    Usage:
        @router.get("/workspace/{id}")
        async def get_workspace(
            id: str,
            t_enum: Callable = Depends(get_enum_translator)
        ):
            workspace = await workspace_service.get(id)
            return {
                "role": workspace.role,
                "role_display": t_enum("workspace_role", workspace.role),
            }
    """
    language = await get_current_language(request)
    service = get_translation_service()

    def translate_enum(enum_type: str, value: str) -> str:
        """
        Translate an enum value into the language captured from the request.

        Args:
            enum_type: Enum type name (e.g., "workspace_role")
            value: Enum value (e.g., "manager")

        Returns:
            Translated enum display name
        """
        return service.translate_enum(enum_type, value, language)

    return translate_enum
|
||||
495
api/app/i18n/exceptions.py
Normal file
495
api/app/i18n/exceptions.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""
|
||||
Internationalized exception classes for i18n system.
|
||||
|
||||
This module provides exception classes that automatically translate
|
||||
error messages based on the current request's language.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to store current locale
|
||||
_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None)
|
||||
|
||||
|
||||
def set_current_locale(locale: str) -> None:
    """
    Store *locale* in the module's ``_current_locale`` context variable.

    Intended to be called by the LanguageMiddleware so that exceptions
    raised later in the same context can pick up the locale.

    Args:
        locale: Locale code (e.g., "zh", "en")
    """
    # The reset token returned by ContextVar.set() is intentionally discarded.
    _current_locale.set(locale)
|
||||
|
||||
|
||||
def get_current_locale() -> Optional[str]:
    """
    Read the locale stored in the ``_current_locale`` context variable.

    Returns:
        Locale code, or None when nothing has been set in this context
    """
    current = _current_locale.get()
    return current
|
||||
|
||||
|
||||
class I18nException(HTTPException):
    """
    HTTP exception whose message is translated at construction time.

    The locale used for translation is, in order of preference:
    1. The explicit ``locale`` argument
    2. The locale stored in the ``_current_locale`` context variable
    3. ``settings.I18N_DEFAULT_LANGUAGE``

    The resulting ``detail`` is a dict with "error_code" and "message",
    plus "params" when message parameters were supplied.

    Usage:
        # Simple error
        raise I18nException(
            error_key="errors.workspace.not_found",
            status_code=404
        )

        # Error with parameters and a custom error code
        raise I18nException(
            error_key="errors.workspace.not_found",
            error_code="WORKSPACE_NOT_FOUND",
            status_code=404,
            workspace_id="123"
        )
    """

    def __init__(
        self,
        error_key: str,
        status_code: int = 400,
        error_code: Optional[str] = None,
        locale: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        **params
    ):
        """
        Translate the error message and initialize the HTTP exception.

        Args:
            error_key: Translation key for the error message
                      (e.g., "errors.workspace.not_found")
            status_code: HTTP status code (default: 400)
            error_code: Custom error code for API clients
                       (default: derived from error_key)
            locale: Target locale for translation (optional);
                   resolved from the context variable when omitted
            headers: Additional HTTP headers
            **params: Parameters for parameterized error messages
        """
        self.error_key = error_key
        self.error_code = error_code or self._generate_error_code(error_key)
        self.params = params

        if locale is None:
            locale = self._get_current_locale()

        message = get_translation_service().translate(error_key, locale, **params)

        detail = {"error_code": self.error_code, "message": message}
        if params:
            # Echo the parameters back so clients can render their own text.
            detail["params"] = params

        super().__init__(status_code=status_code, detail=detail, headers=headers)

        logger.debug(
            f"I18nException raised: {self.error_code} "
            f"(key: {error_key}, locale: {locale})"
        )

    def _get_current_locale(self) -> str:
        """
        Resolve the locale from the context variable, else the default.

        Returns:
            Locale code (e.g., "zh", "en")
        """
        try:
            active = _current_locale.get()
        except Exception as e:
            logger.debug(f"Could not get locale from context: {e}")
        else:
            if active:
                return active

        # Nothing usable in the context: use the configured default.
        from app.core.config import settings
        return settings.I18N_DEFAULT_LANGUAGE

    def _generate_error_code(self, error_key: str) -> str:
        """
        Derive an UPPER_SNAKE_CASE error code from a translation key.

        Converts "errors.workspace.not_found" to "WORKSPACE_NOT_FOUND"

        Args:
            error_key: Translation key

        Returns:
            Error code in UPPER_SNAKE_CASE
        """
        prefix = "errors."
        trimmed = error_key[len(prefix):] if error_key.startswith(prefix) else error_key
        return trimmed.replace(".", "_").upper()
|
||||
|
||||
|
||||
# Specific exception classes for common errors
|
||||
|
||||
class BadRequestError(I18nException):
    """
    Bad request error (400).

    Args:
        error_key: Translation key (default: "errors.common.bad_request")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.bad_request",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 400, error_code, **params)
|
||||
|
||||
|
||||
class UnauthorizedError(I18nException):
    """
    Unauthorized error (401).

    Args:
        error_key: Translation key (default: "errors.auth.unauthorized")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.auth.unauthorized",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 401, error_code, **params)
|
||||
|
||||
|
||||
class ForbiddenError(I18nException):
    """
    Forbidden error (403).

    Args:
        error_key: Translation key (default: "errors.auth.forbidden")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.auth.forbidden",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 403, error_code, **params)
|
||||
|
||||
|
||||
class NotFoundError(I18nException):
    """
    Not found error (404).

    Args:
        error_key: Translation key (default: "errors.common.not_found")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.not_found",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 404, error_code, **params)
|
||||
|
||||
|
||||
class ConflictError(I18nException):
    """
    Conflict error (409).

    Args:
        error_key: Translation key (default: "errors.common.conflict")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.conflict",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 409, error_code, **params)
|
||||
|
||||
|
||||
class ValidationError(I18nException):
    """
    Validation error (422).

    Args:
        error_key: Translation key
                  (default: "errors.common.validation_failed")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.validation_failed",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 422, error_code, **params)
|
||||
|
||||
|
||||
class InternalServerError(I18nException):
    """
    Internal server error (500).

    Args:
        error_key: Translation key
                  (default: "errors.common.internal_error")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.internal_error",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 500, error_code, **params)
|
||||
|
||||
|
||||
class ServiceUnavailableError(I18nException):
    """
    Service unavailable error (503).

    Args:
        error_key: Translation key
                  (default: "errors.common.service_unavailable")
        error_code: Custom error code (default: derived from error_key)
        **params: Forwarded to I18nException
    """

    def __init__(
        self,
        error_key: str = "errors.common.service_unavailable",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 503, error_code, **params)
|
||||
|
||||
|
||||
# Domain-specific exception classes
|
||||
|
||||
class WorkspaceNotFoundError(NotFoundError):
    """
    Workspace not found error.

    Args:
        workspace_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, workspace_id: Optional[str] = None, **params):
        # A falsy workspace_id (None or "") is left out of the parameters.
        extra = {"workspace_id": workspace_id} if workspace_id else {}
        super().__init__(
            error_key="errors.workspace.not_found",
            error_code="WORKSPACE_NOT_FOUND",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class WorkspacePermissionDeniedError(ForbiddenError):
    """
    Workspace permission denied error.

    Args:
        workspace_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, workspace_id: Optional[str] = None, **params):
        # A falsy workspace_id (None or "") is left out of the parameters.
        extra = {"workspace_id": workspace_id} if workspace_id else {}
        super().__init__(
            error_key="errors.workspace.permission_denied",
            error_code="WORKSPACE_PERMISSION_DENIED",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class UserNotFoundError(NotFoundError):
    """
    User not found error.

    Args:
        user_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, user_id: Optional[str] = None, **params):
        # A falsy user_id (None or "") is left out of the parameters.
        extra = {"user_id": user_id} if user_id else {}
        super().__init__(
            error_key="errors.user.not_found",
            error_code="USER_NOT_FOUND",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ConflictError):
    """
    User already exists error.

    Args:
        identifier: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, identifier: Optional[str] = None, **params):
        # A falsy identifier (None or "") is left out of the parameters.
        extra = {"identifier": identifier} if identifier else {}
        super().__init__(
            error_key="errors.user.already_exists",
            error_code="USER_ALREADY_EXISTS",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class TenantNotFoundError(NotFoundError):
    """
    Tenant not found error.

    Args:
        tenant_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, tenant_id: Optional[str] = None, **params):
        # A falsy tenant_id (None or "") is left out of the parameters.
        extra = {"tenant_id": tenant_id} if tenant_id else {}
        super().__init__(
            error_key="errors.tenant.not_found",
            error_code="TENANT_NOT_FOUND",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class TenantSuspendedError(ForbiddenError):
    """
    Tenant suspended error.

    Args:
        tenant_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, tenant_id: Optional[str] = None, **params):
        # A falsy tenant_id (None or "") is left out of the parameters.
        extra = {"tenant_id": tenant_id} if tenant_id else {}
        super().__init__(
            error_key="errors.tenant.suspended",
            error_code="TENANT_SUSPENDED",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class InvalidCredentialsError(UnauthorizedError):
    """
    Invalid credentials error.

    Args:
        **params: Message parameters forwarded to UnauthorizedError
    """

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, then code.
        super().__init__("errors.auth.invalid_credentials", "INVALID_CREDENTIALS", **params)
|
||||
|
||||
|
||||
class TokenExpiredError(UnauthorizedError):
    """
    Token expired error.

    Args:
        **params: Message parameters forwarded to UnauthorizedError
    """

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, then code.
        super().__init__("errors.auth.token_expired", "TOKEN_EXPIRED", **params)
|
||||
|
||||
|
||||
class TokenInvalidError(UnauthorizedError):
    """
    Token invalid error.

    Args:
        **params: Message parameters forwarded to UnauthorizedError
    """

    def __init__(self, **params):
        # Positional order matches UnauthorizedError: key, then code.
        super().__init__("errors.auth.token_invalid", "TOKEN_INVALID", **params)
|
||||
|
||||
|
||||
class FileNotFoundError(NotFoundError):
    """
    File not found error.

    NOTE(review): this class has the same name as the builtin
    ``FileNotFoundError`` and therefore shadows it wherever this name is
    in scope. Import it under an alias if the builtin is also needed.

    Args:
        file_id: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, file_id: Optional[str] = None, **params):
        # A falsy file_id (None or "") is left out of the parameters.
        extra = {"file_id": file_id} if file_id else {}
        super().__init__(
            error_key="errors.file.not_found",
            error_code="FILE_NOT_FOUND",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class FileTooLargeError(BadRequestError):
    """
    File too large error.

    Args:
        max_size: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, max_size: Optional[str] = None, **params):
        # A falsy max_size (None or "") is left out of the parameters.
        extra = {"max_size": max_size} if max_size else {}
        super().__init__(
            error_key="errors.file.too_large",
            error_code="FILE_TOO_LARGE",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class InvalidFileTypeError(BadRequestError):
    """
    Invalid file type error.

    Args:
        file_type: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, file_type: Optional[str] = None, **params):
        # A falsy file_type (None or "") is left out of the parameters.
        extra = {"file_type": file_type} if file_type else {}
        super().__init__(
            error_key="errors.file.invalid_type",
            error_code="INVALID_FILE_TYPE",
            **params,
            **extra
        )
|
||||
|
||||
|
||||
class RateLimitExceededError(I18nException):
    """
    Rate limit exceeded error (429).

    Args:
        **params: Message parameters forwarded to I18nException
    """

    def __init__(self, **params):
        # Positional order matches I18nException: key, status, code.
        super().__init__(
            "errors.api.rate_limit_exceeded",
            429,
            "RATE_LIMIT_EXCEEDED",
            **params
        )
|
||||
|
||||
|
||||
class QuotaExceededError(ForbiddenError):
    """
    Quota exceeded error.

    Args:
        resource: Added to the message parameters when truthy
        **params: Additional message parameters
    """

    def __init__(self, resource: Optional[str] = None, **params):
        # A falsy resource (None or "") is left out of the parameters.
        extra = {"resource": resource} if resource else {}
        super().__init__(
            error_key="errors.api.quota_exceeded",
            error_code="QUOTA_EXCEEDED",
            **params,
            **extra
        )
|
||||
199
api/app/i18n/loader.py
Normal file
199
api/app/i18n/loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Translation file loader for i18n system.
|
||||
|
||||
This module handles loading translation files from multiple directories
|
||||
(community edition + enterprise edition) and provides hot reload support.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationLoader:
    """
    Loads JSON translation files from one or more locale directories.

    Each directory is expected to contain one sub-directory per locale,
    holding ``<namespace>.json`` files. Directories are processed in
    order and later ones are deep-merged over earlier ones, which is how
    enterprise translations override community translations.
    """

    def __init__(self, locales_dirs: Optional[List[str]] = None):
        """
        Initialize the translation loader.

        Args:
            locales_dirs: List of directories containing translation files.
                         If None, will auto-detect from settings.
        """
        if locales_dirs is None:
            locales_dirs = self._detect_locales_dirs()

        self.locales_dirs = list(map(Path, locales_dirs))
        logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}")

    def _detect_locales_dirs(self) -> List[str]:
        """
        Work out the translation directories from application settings.

        The core directory comes first, followed by the premium directory
        (the configured one, or "premium/locales" when none is configured).
        Only directories that exist are returned.

        Returns:
            List of translation directory paths
        """
        from app.core.config import settings

        found: List[str] = []

        # 1. Core locales directory (community edition, required)
        core_dir = Path(settings.I18N_CORE_LOCALES_DIR)
        if not core_dir.exists():
            logger.warning(f"Core locales directory not found: {core_dir}")
        else:
            found.append(str(core_dir))
            logger.debug(f"Found core locales directory: {core_dir}")

        # 2. Premium locales directory (enterprise edition, optional)
        configured = settings.I18N_PREMIUM_LOCALES_DIR
        if configured:
            premium_dir, label = Path(configured), "Found premium locales directory"
        else:
            premium_dir, label = Path("premium/locales"), "Auto-detected premium locales directory"
        if premium_dir.exists():
            found.append(str(premium_dir))
            logger.debug(f"{label}: {premium_dir}")

        if not found:
            logger.error("No translation directories found!")

        return found

    def get_available_locales(self) -> List[str]:
        """
        List every locale that has a sub-directory in any locales directory.

        Hidden sub-directories (names starting with ".") and plain files
        are ignored; locales directories that do not exist are skipped.

        Returns:
            Sorted list of locale codes (e.g., ['en', 'zh'])
        """
        names = {
            child.name
            for root in self.locales_dirs
            if root.exists()
            for child in root.iterdir()
            if child.is_dir() and not child.name.startswith('.')
        }
        return sorted(names)

    def load_locale(self, locale: str) -> Dict[str, Any]:
        """
        Load and merge all translation files for one locale.

        Every ``*.json`` file in ``<dir>/<locale>/`` becomes a namespace
        named after the file stem. When a namespace appears in more than
        one directory, the later directory is deep-merged over the earlier
        one. Files that cannot be read or parsed are logged and skipped.

        Args:
            locale: Locale code (e.g., 'zh', 'en')

        Returns:
            Dictionary of translations organized by namespace
            Format: {namespace: {key: value, ...}, ...}
        """
        merged: Dict[str, Any] = {}

        for root in self.locales_dirs:
            locale_path = root / locale
            if not locale_path.exists():
                logger.debug(f"Locale directory not found: {locale_path}")
                continue

            for source in locale_path.glob("*.json"):
                namespace = source.stem

                try:
                    with source.open("r", encoding="utf-8") as handle:
                        payload = json.load(handle)

                    if namespace not in merged:
                        merged[namespace] = payload
                        logger.debug(
                            f"Loaded translations: {locale}/{namespace} from {source}"
                        )
                    else:
                        # Same namespace seen in an earlier directory: overlay.
                        merged[namespace] = self._deep_merge(merged[namespace], payload)
                        logger.debug(
                            f"Merged translations: {locale}/{namespace} from {source}"
                        )

                except json.JSONDecodeError as e:
                    logger.error(
                        f"Failed to parse JSON file {source}: {e}"
                    )
                except Exception as e:
                    # Best-effort loading: one bad file must not block the rest.
                    logger.error(
                        f"Failed to load translation file {source}: {e}"
                    )

        if not merged:
            logger.warning(f"No translations found for locale: {locale}")

        return merged

    def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """
        Re-read translation files from disk.

        Args:
            locale: Specific locale to reload. If None, reloads all locales.

        Returns:
            Dictionary of reloaded translations
            Format: {locale: {namespace: {key: value}}}
        """
        if locale:
            logger.info(f"Reloading translations for locale: {locale}")
            return {locale: self.load_locale(locale)}

        logger.info("Reloading all translations")
        return {loc: self.load_locale(loc) for loc in self.get_available_locales()}

    def _deep_merge(self, base: Dict, override: Dict) -> Dict:
        """
        Recursively overlay *override* on a copy of *base*.

        Where both sides hold a dict under the same key the two are merged
        recursively; in every other case the override value wins. *base*
        itself is not modified.

        Args:
            base: Base dictionary
            override: Dictionary with values to override

        Returns:
            Merged dictionary
        """
        merged = base.copy()

        for key, incoming in override.items():
            both_dicts = (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(incoming, dict)
            )
            merged[key] = self._deep_merge(merged[key], incoming) if both_dicts else incoming

        return merged
|
||||
382
api/app/i18n/logger.py
Normal file
382
api/app/i18n/logger.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
Translation logging for i18n system.
|
||||
|
||||
This module provides:
|
||||
- TranslationLogger for recording missing translations
|
||||
- Missing translation report generation
|
||||
- Integration with existing logging system
|
||||
- Structured logging for translation events
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Set
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TranslationLogger:
|
||||
"""
|
||||
Logger for translation events and missing translations.
|
||||
|
||||
Features:
|
||||
- Records missing translations with context
|
||||
- Generates missing translation reports
|
||||
- Integrates with existing logging system
|
||||
- Provides structured logging for analysis
|
||||
"""
|
||||
|
||||
def __init__(self, log_file: Optional[str] = None):
    """
    Initialize translation logger.

    Sets up the in-memory bookkeeping and attaches a WARNING-level file
    handler to the shared "i18n.missing_translations" logger. Because
    that logger is process-wide, a handler already attached for the same
    file is reused instead of adding a second one; otherwise each new
    instance would duplicate every log line and keep another file open.

    Args:
        log_file: Optional custom log file path for missing translations
    """
    import os

    self.log_file = log_file or "logs/i18n/missing_translations.log"
    self._missing_translations: Dict[str, Set[str]] = defaultdict(set)
    self._missing_with_context: List[Dict] = []
    self._max_context_entries = 10000  # Keep last 10k entries

    # Ensure log directory exists
    log_path = Path(self.log_file)
    log_path.parent.mkdir(parents=True, exist_ok=True)

    # Dedicated logger for missing translations
    self._logger = logging.getLogger("i18n.missing_translations")
    self._logger.setLevel(logging.WARNING)
    self._logger.propagate = False  # Don't propagate to root logger

    # FileHandler stores its target as an absolute path in baseFilename.
    target = os.path.abspath(self.log_file)
    existing = next(
        (
            handler
            for handler in self._logger.handlers
            if isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
        ),
        None,
    )

    if existing is not None:
        self._file_handler = existing
    else:
        self._file_handler = logging.FileHandler(
            self.log_file,
            encoding='utf-8'
        )
        self._file_handler.setLevel(logging.WARNING)
        self._file_handler.setFormatter(
            logging.Formatter(
                fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
        )
        self._logger.addHandler(self._file_handler)

    logger.info(f"TranslationLogger initialized with log file: {self.log_file}")
|
||||
|
||||
def log_missing_translation(
    self,
    key: str,
    locale: str,
    context: Optional[Dict] = None
):
    """
    Record a translation key that could not be found.

    The key is added to the per-locale set of missing keys, a timestamped
    entry is appended to the bounded context history, and a warning is
    written to the dedicated missing-translations logger.

    Args:
        key: Translation key that was not found
        locale: Locale code
        context: Optional context information (e.g., request path, user info)
    """
    self._missing_translations[locale].add(key)

    record = {
        "timestamp": datetime.now().isoformat(),
        "key": key,
        "locale": locale,
        "context": context or {}
    }

    # Bounded history: drop the oldest entry once the cap is reached.
    history = self._missing_with_context
    if len(history) >= self._max_context_entries:
        history.pop(0)
    history.append(record)

    suffix = f" (context: {context})" if context else ""
    self._logger.warning(
        f"Missing translation: key='{key}', locale='{locale}'{suffix}"
    )
|
||||
|
||||
def log_translation_error(
    self,
    error_type: str,
    message: str,
    key: Optional[str] = None,
    locale: Optional[str] = None,
    context: Optional[Dict] = None
):
    """
    Log a translation error to the dedicated missing-translations logger.

    The record is emitted at ERROR level. When *context* is provided it
    is appended to the message in the same " (context: ...)" form used
    for missing-translation records.

    Args:
        error_type: Type of error (e.g., "format_error", "parameter_missing")
        message: Error message
        key: Translation key (optional)
        locale: Locale code (optional)
        context: Optional context information
    """
    # Only mention context when there is some, keeping the common case short.
    context_str = f" (context: {context})" if context else ""

    self._logger.error(
        f"Translation error: {error_type} - {message} "
        f"(key: {key}, locale: {locale}){context_str}"
    )
|
||||
|
||||
def log_translation_success(
    self,
    key: str,
    locale: str,
    duration_ms: Optional[float] = None
):
    """
    Log a successful translation (debug level).

    Args:
        key: Translation key
        locale: Locale code
        duration_ms: Optional duration in milliseconds; included in the
                    message whenever it is provided, including 0.0
    """
    # Compare against None so a measured duration of 0.0 is still reported.
    duration_str = f" ({duration_ms:.3f}ms)" if duration_ms is not None else ""
    logger.debug(
        f"Translation success: key='{key}', locale='{locale}'{duration_str}"
    )
|
||||
|
||||
def get_missing_translations(
    self,
    locale: Optional[str] = None
) -> Dict[str, List[str]]:
    """
    Return the recorded missing translation keys, sorted per locale.

    Args:
        locale: Specific locale (optional, returns all if None)

    Returns:
        Dictionary mapping locale to its sorted list of missing keys.
        When a locale is requested that has no records, it maps to an
        empty list.
    """
    store = self._missing_translations
    # .get() is used so that asking for an unknown locale does not create it.
    selected = {locale: store.get(locale, ())} if locale else store
    return {loc: sorted(keys) for loc, keys in selected.items()}
|
||||
|
||||
def get_missing_with_context(
    self,
    locale: Optional[str] = None,
    limit: Optional[int] = None
) -> List[Dict]:
    """
    Get missing translations with context.

    Entries are returned oldest first. The returned list is always a new
    list, so callers cannot alter the logger's internal history by
    mutating it.

    Args:
        locale: Filter by locale (optional)
        limit: Maximum number of entries to return (optional). The most
              recent entries are kept; a limit of 0 or less yields an
              empty list, and None means no limit.

    Returns:
        List of missing translation entries with context
    """
    entries = self._missing_with_context

    # Filter by locale if specified
    if locale:
        entries = [e for e in entries if e["locale"] == locale]

    # entries[-0:] would be the whole list, so non-positive limits are
    # handled explicitly instead of relying on slicing.
    if limit is not None:
        entries = entries[-limit:] if limit > 0 else []

    return list(entries)
|
||||
|
||||
def generate_report(
    self,
    locale: Optional[str] = None,
    output_file: Optional[str] = None
) -> Dict:
    """
    Build a summary report of missing translations.

    The report contains the generation timestamp, the total number of
    missing keys, a per-locale breakdown (count and keys) and up to the
    100 most recent context entries. When *output_file* is given, the
    report is also written there as JSON, creating parent directories
    as needed.

    Args:
        locale: Specific locale (optional, generates for all if None)
        output_file: Optional file path to save report as JSON

    Returns:
        Report dictionary
    """
    missing = self.get_missing_translations(locale)
    generated_at = datetime.now().isoformat()

    by_locale = {
        loc: {"count": len(keys), "keys": keys}
        for loc, keys in missing.items()
    }

    report = {
        "generated_at": generated_at,
        "total_missing": sum(section["count"] for section in by_locale.values()),
        "missing_by_locale": by_locale,
        "recent_context": self.get_missing_with_context(locale, limit=100)
    }

    if not output_file:
        return report

    # Persist a copy of the report alongside returning it.
    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open('w', encoding='utf-8') as handle:
        json.dump(report, handle, indent=2, ensure_ascii=False)

    logger.info(f"Missing translation report saved to: {output_file}")
    return report
|
||||
|
||||
def get_statistics(self) -> Dict:
    """
    Summarize the recorded missing translations.

    Returns:
        Dict with the total number of missing keys, the number of locales
        that have any record, per-locale counts, per-namespace counts (the
        namespace is the text before the first ``.`` of a key, or
        ``'unknown'`` when the key has no dot) and the number of stored
        context entries.
    """
    per_locale = {}
    per_namespace = {}
    total_missing = 0

    # Single pass over the records feeds all three aggregates.
    for loc, keys in self._missing_translations.items():
        per_locale[loc] = len(keys)
        total_missing += len(keys)
        for key in keys:
            head, dot, _ = key.partition('.')
            namespace = head if dot else 'unknown'
            per_namespace[namespace] = per_namespace.get(namespace, 0) + 1

    return {
        "total_missing": total_missing,
        "locales_affected": len(self._missing_translations),
        "missing_by_locale": per_locale,
        "missing_by_namespace": per_namespace,
        "total_context_entries": len(self._missing_with_context)
    }
|
||||
|
||||
def clear(self, locale: Optional[str] = None):
    """
    Drop recorded missing translations.

    Args:
        locale: Locale whose records should be removed. When omitted (or
            empty), every record for every locale is removed.
    """
    if not locale:
        # No locale given: wipe both stores in place.
        self._missing_translations.clear()
        self._missing_with_context.clear()
        logger.info("Cleared all missing translations")
        return

    self._missing_translations.pop(locale, None)

    # Rebuild the context list without the entries of the cleared locale.
    remaining = []
    for entry in self._missing_with_context:
        if entry["locale"] != locale:
            remaining.append(entry)
    self._missing_with_context = remaining

    logger.info(f"Cleared missing translations for locale: {locale}")
|
||||
|
||||
def export_to_json(self, output_file: str):
    """
    Write all missing-translation data to a JSON file.

    The payload holds an ISO ``exported_at`` timestamp, the missing keys
    for every locale, the statistics summary and at most 1000 context
    entries. Missing parent directories of the target are created.

    Args:
        output_file: Output file path
    """
    payload = {"exported_at": datetime.now().isoformat()}
    payload["missing_translations"] = self.get_missing_translations()
    payload["statistics"] = self.get_statistics()
    payload["recent_context"] = self.get_missing_with_context(limit=1000)

    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with open(destination, 'w', encoding='utf-8') as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)

    logger.info(f"Missing translations exported to: {output_file}")
|
||||
|
||||
def __del__(self):
    """
    Best-effort cleanup of the file handler when the object is collected.

    The handler is detached from the logger *before* it is closed. Closing
    first leaves a closed handler attached for a moment, and a file handler
    opened in append mode reopens its stream if a record reaches it then.
    """
    try:
        handler = getattr(self, '_file_handler', None)
        if handler is None:
            return

        target_logger = getattr(self, '_logger', None)
        if target_logger is not None:
            target_logger.removeHandler(handler)
        handler.close()
    except Exception:
        # Deliberately swallowed: a destructor may run during interpreter
        # shutdown or on a half-initialized instance and must never raise.
        pass
|
||||
|
||||
|
||||
# Process-wide TranslationLogger, created on first use by get_translation_logger()
_translation_logger: Optional[TranslationLogger] = None


def get_translation_logger() -> TranslationLogger:
    """
    Return the shared TranslationLogger, creating it on the first call.

    Returns:
        The module-level TranslationLogger instance
    """
    global _translation_logger
    if _translation_logger is not None:
        return _translation_logger

    _translation_logger = TranslationLogger()
    return _translation_logger
|
||||
|
||||
|
||||
def log_missing_translation(
    key: str,
    locale: str,
    context: Optional[Dict] = None
):
    """
    Record a missing translation on the shared translation logger.

    Args:
        key: Translation key
        locale: Locale code
        context: Optional context information, forwarded unchanged
    """
    get_translation_logger().log_missing_translation(key, locale, context)
|
||||
|
||||
|
||||
def log_translation_error(
    error_type: str,
    message: str,
    key: Optional[str] = None,
    locale: Optional[str] = None,
    context: Optional[Dict] = None
):
    """
    Record a translation error on the shared translation logger.

    All arguments are forwarded positionally, in declaration order.

    Args:
        error_type: Type of error
        message: Error message
        key: Translation key (optional)
        locale: Locale code (optional)
        context: Optional context information
    """
    shared_logger = get_translation_logger()
    shared_logger.log_translation_error(error_type, message, key, locale, context)
|
||||
337
api/app/i18n/metrics.py
Normal file
337
api/app/i18n/metrics.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Performance monitoring and metrics for i18n system.
|
||||
|
||||
This module provides:
|
||||
- Translation request counters
|
||||
- Translation timing metrics
|
||||
- Missing translation tracking
|
||||
- Performance monitoring decorators
|
||||
- Prometheus-compatible metrics
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
logger = logging.getLogger(__name__)


class TranslationMetrics:
    """
    In-memory metrics collector for translation operations.

    Tracks:
    - Translation request counts per locale and per namespace
    - Translation timing samples (bounded window, milliseconds)
    - Missing translation keys per locale
    - Error counts per error type
    """

    def __init__(self):
        """Create empty counters and remember the start time for uptime."""
        # Function-scope import: the file-level imports only bring in
        # ``defaultdict`` from ``collections``.
        from collections import deque

        # Request counters by locale
        self._request_counts: Dict[str, int] = defaultdict(int)

        # Missing translation keys, grouped by locale
        self._missing_translations: Dict[str, set] = defaultdict(set)

        # Timing samples (in milliseconds). A deque evicts the oldest sample
        # in O(1); ``list.pop(0)`` shifts the whole window on every call once
        # it is full.
        self._timing_data = deque()
        self._max_timing_samples = 10000  # Keep last 10k samples

        # Locale usage
        self._locale_usage: Dict[str, int] = defaultdict(int)

        # Namespace usage
        self._namespace_usage: Dict[str, int] = defaultdict(int)

        # Error counts by error type
        self._error_counts: Dict[str, int] = defaultdict(int)

        # Start time, used by get_summary() to compute uptime
        self._start_time = datetime.now()

        logger.info("TranslationMetrics initialized")

    def record_request(self, locale: str, namespace: Optional[str] = None):
        """
        Record a translation request.

        Args:
            locale: Locale code
            namespace: Translation namespace (optional); counted only when
                truthy
        """
        self._request_counts[locale] += 1
        self._locale_usage[locale] += 1

        if namespace:
            self._namespace_usage[namespace] += 1

    def record_missing(self, key: str, locale: str):
        """
        Record a missing translation.

        Keys are stored in a set per locale, so repeated reports of the same
        key are counted once.

        Args:
            key: Translation key
            locale: Locale code
        """
        self._missing_translations[locale].add(key)
        logger.debug(f"Missing translation recorded: {key} (locale: {locale})")

    def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"):
        """
        Record the duration of one translation operation.

        Once ``_max_timing_samples`` samples are stored, the oldest one is
        dropped before the new one is appended.

        Args:
            duration_ms: Duration in milliseconds
            locale: Locale code
            operation: Operation type
        """
        # Keep only recent samples to avoid memory bloat
        if len(self._timing_data) >= self._max_timing_samples:
            self._timing_data.popleft()

        self._timing_data.append({
            "duration_ms": duration_ms,
            "locale": locale,
            "operation": operation,
            "timestamp": time.time()
        })

    def record_error(self, error_type: str):
        """
        Record an error.

        Args:
            error_type: Type of error (used as the counter key)
        """
        self._error_counts[error_type] += 1

    def get_summary(self) -> Dict[str, Any]:
        """
        Get a snapshot of all collected metrics.

        Returns:
            Dictionary with uptime, request totals, missing-translation
            counts, timing statistics, usage counters and error counts.
            Counter values are copied into plain dicts.
        """
        uptime_seconds = (datetime.now() - self._start_time).total_seconds()

        return {
            "uptime_seconds": round(uptime_seconds, 2),
            "total_requests": sum(self._request_counts.values()),
            "requests_per_locale": dict(self._request_counts),
            "total_missing_translations": sum(
                len(keys) for keys in self._missing_translations.values()
            ),
            "missing_by_locale": {
                loc: len(keys)
                for loc, keys in self._missing_translations.items()
            },
            "timing": self._calculate_timing_stats(),
            "locale_usage": dict(self._locale_usage),
            "namespace_usage": dict(self._namespace_usage),
            "error_counts": dict(self._error_counts)
        }

    def _calculate_timing_stats(self) -> Dict[str, Any]:
        """
        Calculate statistics over the stored timing samples.

        Percentiles use the nearest-rank method: the p-th percentile is the
        value at 1-based rank ``ceil(p/100 * count)`` of the sorted samples.

        Returns:
            Dictionary with ``count``, ``avg_ms``, ``min_ms``, ``max_ms``,
            ``p50_ms``, ``p95_ms`` and ``p99_ms`` (all zero when no samples
            are stored); non-zero values are rounded to 3 decimals.
        """
        if not self._timing_data:
            return {
                "count": 0,
                "avg_ms": 0,
                "min_ms": 0,
                "max_ms": 0,
                "p50_ms": 0,
                "p95_ms": 0,
                "p99_ms": 0
            }

        durations = sorted(sample["duration_ms"] for sample in self._timing_data)
        count = len(durations)

        def nearest_rank(percent: int) -> float:
            # Integer ceil(percent * count / 100); avoids float rounding and
            # the upward bias of ``int(count * q)`` (which returned the
            # maximum as the median of two samples).
            rank = (percent * count + 99) // 100
            return durations[max(rank, 1) - 1]

        return {
            "count": count,
            "avg_ms": round(sum(durations) / count, 3),
            "min_ms": round(durations[0], 3),
            "max_ms": round(durations[-1], 3),
            "p50_ms": round(nearest_rank(50), 3),
            "p95_ms": round(nearest_rank(95), 3),
            "p99_ms": round(nearest_rank(99), 3)
        }

    def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]:
        """
        Get missing translations.

        Args:
            locale: Specific locale (optional, returns all if None or empty)

        Returns:
            Dictionary mapping locale to a list of missing keys. When a
            locale is given, the result has exactly that one entry (an empty
            list if nothing was recorded for it).
        """
        if locale:
            return {locale: list(self._missing_translations.get(locale, set()))}

        return {
            loc: list(keys)
            for loc, keys in self._missing_translations.items()
        }

    def reset(self):
        """Reset all metrics and restart the uptime clock."""
        self._request_counts.clear()
        self._missing_translations.clear()
        self._timing_data.clear()
        self._locale_usage.clear()
        self._namespace_usage.clear()
        self._error_counts.clear()
        self._start_time = datetime.now()
        logger.info("Metrics reset")

    @staticmethod
    def _escape_label_value(value: Any) -> str:
        """Escape backslash, newline and double quote for a label value."""
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace("\n", "\\n")
            .replace('"', '\\"')
        )

    def export_prometheus(self) -> str:
        """
        Export metrics in Prometheus text format.

        Label values (locale, error type) are escaped so that a quote,
        backslash or newline in them cannot break the exposition output.

        Returns:
            Prometheus-formatted metrics string (lines joined with ``\\n``,
            without a trailing newline)
        """
        escape = self._escape_label_value
        lines = []

        # Translation requests counter
        lines.append("# HELP i18n_translation_requests_total Total number of translation requests")
        lines.append("# TYPE i18n_translation_requests_total counter")
        for locale, count in self._request_counts.items():
            lines.append(f'i18n_translation_requests_total{{locale="{escape(locale)}"}} {count}')

        # Missing translations counter
        lines.append("# HELP i18n_missing_translations_total Total number of missing translations")
        lines.append("# TYPE i18n_missing_translations_total counter")
        for locale, keys in self._missing_translations.items():
            lines.append(f'i18n_missing_translations_total{{locale="{escape(locale)}"}} {len(keys)}')

        # Timing metrics (computed over the retained sample window only)
        timing_stats = self._calculate_timing_stats()
        lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds")
        lines.append("# TYPE i18n_translation_duration_ms summary")
        lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}')
        lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}')
        lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}')

        # Error counter
        lines.append("# HELP i18n_errors_total Total number of i18n errors")
        lines.append("# TYPE i18n_errors_total counter")
        for error_type, count in self._error_counts.items():
            lines.append(f'i18n_errors_total{{type="{escape(error_type)}"}} {count}')

        return "\n".join(lines)
|
||||
|
||||
|
||||
# Process-wide TranslationMetrics, created on first use by get_metrics()
_metrics: Optional[TranslationMetrics] = None


def get_metrics() -> TranslationMetrics:
    """
    Return the shared TranslationMetrics, creating it on the first call.

    Returns:
        The module-level TranslationMetrics instance
    """
    global _metrics
    if _metrics is not None:
        return _metrics

    _metrics = TranslationMetrics()
    return _metrics
|
||||
|
||||
|
||||
def monitor_performance(operation: str = "translate"):
    """
    Decorator that records timing and errors of a translation operation.

    On success the call duration (milliseconds) is recorded together with
    the locale the call was made for; on failure the exception's class name
    is counted and the exception is re-raised unchanged.

    The locale is taken from the ``locale`` keyword argument or, when the
    wrapped function declares a positional ``locale`` parameter, from the
    matching positional argument (this works for both plain functions and
    methods, where ``self`` shifts the position). Anything that is not a
    non-empty string is reported as ``"unknown"``.

    Args:
        operation: Operation name for metrics

    Returns:
        Decorated function

    Example:
        @monitor_performance("translate")
        def translate(key: str, locale: str) -> str:
            ...
    """
    def decorator(func: Callable) -> Callable:
        import inspect

        # Work out once, at decoration time, where ``locale`` sits among the
        # positional parameters so the per-call lookup stays cheap.
        try:
            parameters = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            parameters = []

        locale_index = None
        for position, parameter in enumerate(parameters):
            if parameter.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                break
            if parameter.name == "locale":
                locale_index = position
                break

        def resolve_locale(args, kwargs) -> str:
            locale = kwargs.get("locale")
            if locale is None and locale_index is not None and len(args) > locale_index:
                locale = args[locale_index]
            if isinstance(locale, str) and locale:
                return locale
            return "unknown"

        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                # Record error, then let the caller see the original exception
                get_metrics().record_error(type(e).__name__)
                raise

            # Record timing outside the ``try`` so a failure inside the
            # metrics code is never miscounted as a failure of ``func``.
            duration_ms = (time.perf_counter() - start_time) * 1000
            get_metrics().record_timing(
                duration_ms, resolve_locale(args, kwargs), operation
            )

            return result

        return wrapper
    return decorator
|
||||
|
||||
|
||||
def track_missing_translation(key: str, locale: str):
    """
    Record a missing translation on the shared metrics instance.

    Args:
        key: Translation key
        locale: Locale code
    """
    get_metrics().record_missing(key, locale)
|
||||
|
||||
|
||||
def track_translation_request(locale: str, namespace: str = None):
    """
    Record a translation request on the shared metrics instance.

    Args:
        locale: Locale code
        namespace: Translation namespace (optional)
    """
    get_metrics().record_request(locale, namespace)
|
||||
202
api/app/i18n/middleware.py
Normal file
202
api/app/i18n/middleware.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
Language detection middleware for i18n system.
|
||||
|
||||
This middleware determines the language to use for each request based on:
|
||||
1. Query parameter (?lang=en)
|
||||
2. Accept-Language HTTP header
|
||||
3. User language preference (from database)
|
||||
4. Tenant default language
|
||||
5. System default language
|
||||
|
||||
The detected language is injected into request.state.language and
|
||||
added to the response Content-Language header.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LanguageMiddleware(BaseHTTPMiddleware):
    """
    Language detection middleware.

    For each request, picks a language from several sources in a fixed
    priority order, falls back to the system default when the pick is not
    in ``settings.I18N_SUPPORTED_LANGUAGES``, stores the result on
    ``request.state.language`` and echoes it in the response
    ``Content-Language`` header.
    """

    # Weight parameter of an Accept-Language entry, e.g. ``q=0.8``.
    _QVALUE_PATTERN = re.compile(r"q=([\d.]+)")

    async def dispatch(self, request: Request, call_next):
        """
        Process the request and determine the language.

        Args:
            request: The incoming request
            call_next: The next middleware/handler in the chain

        Returns:
            Response with Content-Language header added
        """
        # Determine the language for this request
        language = await self._determine_language(request)

        # Validate language is supported
        from app.core.config import settings
        if language not in settings.I18N_SUPPORTED_LANGUAGES:
            # ``!r`` escapes control characters: the value may come straight
            # from the ``?lang=`` query parameter and must not be able to
            # inject extra log lines.
            logger.warning(
                f"Unsupported language {language!r} requested, "
                f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}"
            )
            language = settings.I18N_DEFAULT_LANGUAGE

        # Inject language into request state
        request.state.language = language

        # Also set in context variable for exception handling
        from app.i18n.exceptions import set_current_locale
        set_current_locale(language)

        logger.debug(f"Request language set to: {language}")

        # Process the request
        response = await call_next(request)

        # Add Content-Language header to response
        response.headers["Content-Language"] = language

        return response

    async def _determine_language(self, request: Request) -> str:
        """
        Determine the language to use based on priority order.

        Priority:
        1. Query parameter (?lang=en)
        2. Accept-Language HTTP header
        3. User language preference (``request.state.user.preferred_language``)
        4. Tenant default language (``request.state.tenant.default_language``)
        5. System default language

        The returned value is not validated here; ``dispatch`` checks it
        against the supported languages.

        Args:
            request: The incoming request

        Returns:
            Language code (e.g., "zh", "en")
        """
        from app.core.config import settings

        # 1. Check query parameter (?lang=en)
        if "lang" in request.query_params:
            lang = request.query_params["lang"].strip().lower()
            if lang:
                logger.debug(f"Language from query parameter: {lang}")
                return lang

        # 2. Check Accept-Language HTTP header
        if "Accept-Language" in request.headers:
            lang = self._parse_accept_language(
                request.headers["Accept-Language"]
            )
            if lang:
                logger.debug(f"Language from Accept-Language header: {lang}")
                return lang

        # 3. Check user language preference (requires authentication)
        # Note: This assumes user is already loaded into request.state by auth middleware
        if hasattr(request.state, "user") and request.state.user:
            user = request.state.user
            if hasattr(user, "preferred_language") and user.preferred_language:
                logger.debug(
                    f"Language from user preference: {user.preferred_language}"
                )
                return user.preferred_language

        # 4. Check tenant default language
        # Note: This assumes tenant is already loaded into request.state
        if hasattr(request.state, "tenant") and request.state.tenant:
            tenant = request.state.tenant
            if hasattr(tenant, "default_language") and tenant.default_language:
                logger.debug(
                    f"Language from tenant default: {tenant.default_language}"
                )
                return tenant.default_language

        # 5. Fall back to system default language
        logger.debug(
            f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}"
        )
        return settings.I18N_DEFAULT_LANGUAGE

    def _parse_accept_language(self, header: str) -> Optional[str]:
        """
        Parse the Accept-Language HTTP header.

        The Accept-Language header format:
        Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7

        This method:
        1. Parses all language codes and their quality values
        2. Extracts the base language code (zh-CN -> zh)
        3. Drops entries with a quality of 0, which HTTP defines as
           "not acceptable"
        4. Sorts by quality value (higher first; ties keep header order)
        5. Returns the first supported language

        A missing or unparsable quality value counts as 1.0.

        Args:
            header: Accept-Language header value

        Returns:
            Language code if found and supported, None otherwise

        Examples:
            _parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8")
            # => "zh" (if zh is supported)

            _parse_accept_language("en-US,en;q=0.9")
            # => "en" (if en is supported)

            _parse_accept_language("zh;q=0,en;q=0.5")
            # => "en" (zh is explicitly refused by the client)
        """
        from app.core.config import settings

        if not header:
            return None

        # Parse language preferences with quality values
        languages = []

        for item in header.split(","):
            item = item.strip()
            if not item:
                continue

            # Split language code from its parameters (";q=0.9")
            lang_code, _, parameters = item.partition(";")

            # Extract base language code (zh-CN -> zh, en-US -> en)
            base_lang = lang_code.strip().split("-")[0].lower()

            # Extract quality value (default: 1.0)
            quality = 1.0
            q_match = self._QVALUE_PATTERN.search(parameters)
            if q_match:
                try:
                    quality = float(q_match.group(1))
                except ValueError:
                    # e.g. "q=1.2.3" matches the pattern but is not a number
                    quality = 1.0

            # q=0 means the client does not accept this language at all.
            if quality <= 0:
                continue

            languages.append((base_lang, quality))

        # Sort by quality value (descending); the sort is stable, so entries
        # with equal quality keep their order from the header.
        languages.sort(key=lambda entry: entry[1], reverse=True)

        # Return the first supported language
        for candidate, _ in languages:
            if candidate in settings.I18N_SUPPORTED_LANGUAGES:
                return candidate

        return None
|
||||
221
api/app/i18n/serializers.py
Normal file
221
api/app/i18n/serializers.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
国际化响应序列化器
|
||||
|
||||
提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class I18nResponseMixin:
    """Mixin that adds localized ``<field>_display`` entries to response data.

    For every enum field declared by ``_get_enum_fields()`` that is present
    and not ``None`` in the data, a sibling key with the ``_display`` suffix
    is added, holding the translation of the enum value.

    Usage:
    1. Inherit from this class
    2. Override ``_get_enum_fields()`` to declare the enum fields to translate
    3. Call ``serialize_with_i18n()`` to serialize the data

    Example:
        class WorkspaceSerializer(I18nResponseMixin):
            def _get_enum_fields(self) -> Dict[str, str]:
                return {
                    "role": "workspace_role",
                    "status": "workspace_status"
                }

            def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict:
                data = {
                    "id": str(workspace.id),
                    "name": workspace.name,
                    "role": workspace.role,
                    "status": workspace.status
                }
                return self.serialize_with_i18n(data, locale)
    """

    def serialize_with_i18n(
        self,
        data: Any,
        locale: str = "zh"
    ) -> Union[Dict, List[Dict], Any]:
        """Serialize data and add the i18n display fields.

        Args:
            data: A dict, a Pydantic model, or a list of those. Pydantic
                models are converted with ``model_dump()``, including models
                that appear as list items.
            locale: Language code

        Returns:
            The serialized data with ``_display`` fields added. List items
            that are neither dicts nor Pydantic models, and any other input
            type, are returned unchanged.
        """
        if isinstance(data, list):
            return [self._serialize_item(item, locale) for item in data]
        return self._serialize_item(data, locale)

    def _serialize_item(self, item: Any, locale: str) -> Any:
        """Serialize one dict or Pydantic model; pass anything else through."""
        if isinstance(item, BaseModel):
            item = item.model_dump()
        if isinstance(item, dict):
            return self._serialize_dict(item, locale)
        return item

    def _serialize_dict(self, data: Dict, locale: str) -> Dict:
        """Return a shallow copy of ``data`` with ``_display`` fields added.

        Args:
            data: Dictionary data (not modified)
            locale: Language code

        Returns:
            Copy of the dictionary with one ``<field>_display`` entry per
            declared enum field that is present and not ``None``
        """
        from enum import Enum

        from app.i18n.service import get_translation_service

        translation_service = get_translation_service()

        result = data.copy()

        # Add a _display field for every declared enum field
        for field, enum_type in self._get_enum_fields().items():
            value = result.get(field)
            if value is None:
                continue

            # ``str()`` of an Enum member is "ClassName.MEMBER", not the
            # stored value, so unwrap members before building the lookup key
            # (``model_dump()`` keeps Enum instances by default).
            if isinstance(value, Enum):
                value = value.value

            result[f"{field}_display"] = translation_service.translate_enum(
                enum_type=enum_type,
                value=str(value),
                locale=locale
            )

        return result

    def _get_enum_fields(self) -> Dict[str, str]:
        """Return the enum fields that need a translated display value.

        Subclasses override this to return a mapping of field name to enum
        type; the base implementation declares no fields.

        Returns:
            Mapping of field name to enum type,
            e.g. {"role": "workspace_role", "status": "workspace_status"}
        """
        return {}
|
||||
|
||||
|
||||
class WorkspaceSerializer(I18nResponseMixin):
    """Workspace serializer.

    Adds i18n display fields to workspace responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace."""
        return dict(role="workspace_role", status="workspace_status")

    def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace.

        Args:
            workspace_data: Workspace data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized workspace data including the i18n fields
        """
        serialized = self.serialize_with_i18n(workspace_data, locale)
        return serialized

    def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspaces.

        Args:
            workspaces: Workspace list
            locale: Language code

        Returns:
            List of serialized workspaces, in input order
        """
        serialized = []
        for workspace in workspaces:
            serialized.append(self.serialize(workspace, locale))
        return serialized
|
||||
|
||||
|
||||
class WorkspaceMemberSerializer(I18nResponseMixin):
    """Workspace member serializer.

    Adds i18n display fields to workspace member responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace member."""
        return dict(role="workspace_role")

    def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace member.

        Args:
            member_data: Member data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized member data including the i18n fields
        """
        serialized = self.serialize_with_i18n(member_data, locale)
        return serialized

    def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace members.

        Args:
            members: Member list
            locale: Language code

        Returns:
            List of serialized members, in input order
        """
        serialized = []
        for member in members:
            serialized.append(self.serialize(member, locale))
        return serialized
|
||||
|
||||
|
||||
class WorkspaceInviteSerializer(I18nResponseMixin):
    """Workspace invite serializer.

    Adds i18n display fields to workspace invite responses.
    """

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare the enum fields of a workspace invite."""
        return dict(status="invite_status", role="workspace_role")

    def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace invite.

        Args:
            invite_data: Invite data (dict or Pydantic model)
            locale: Language code

        Returns:
            Serialized invite data including the i18n fields
        """
        serialized = self.serialize_with_i18n(invite_data, locale)
        return serialized

    def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace invites.

        Args:
            invites: Invite list
            locale: Language code

        Returns:
            List of serialized invites, in input order
        """
        serialized = []
        for invite in invites:
            serialized.append(self.serialize(invite, locale))
        return serialized
|
||||
370
api/app/i18n/service.py
Normal file
370
api/app/i18n/service.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Translation service for i18n system.
|
||||
|
||||
This module provides the core translation functionality including:
|
||||
- Translation lookup with fallback mechanism
|
||||
- Parameterized message support
|
||||
- Enum value translation
|
||||
- Memory caching for performance
|
||||
- Performance monitoring and metrics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.i18n.loader import TranslationLoader
|
||||
from app.i18n.cache import TranslationCache
|
||||
from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request
|
||||
from app.i18n.logger import get_translation_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationService:
|
||||
"""
|
||||
Translation service that provides:
|
||||
- Fast translation lookup with memory cache
|
||||
- Parameterized message support ({param} syntax)
|
||||
- Fallback mechanism (current locale → default locale → key)
|
||||
- Enum value translation
|
||||
- Deep merge of multi-directory translations
|
||||
"""
|
||||
|
||||
def __init__(self, locales_dirs: Optional[list] = None):
    """
    Set up the loader, cache, metrics and logger, and preload translations.

    Args:
        locales_dirs: List of directories containing translation files,
                     handed to ``TranslationLoader`` as-is (may be None).
    """
    from app.core.config import settings

    self.loader = TranslationLoader(locales_dirs)

    # Behaviour switches and locales come straight from the settings object.
    self.default_locale = settings.I18N_DEFAULT_LANGUAGE
    self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE
    self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS
    self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE

    # LRU capacity is optional in settings; 1000 is used when it is absent.
    cache_capacity = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000)
    self.cache = TranslationCache(
        max_lru_size=cache_capacity,
        enable_lazy_load=False  # Load all at startup for now
    )

    # Fill the cache before the service is handed to any caller.
    self._load_all_locales()

    # Shared metrics collector and translation logger.
    self.metrics = get_metrics()
    self.translation_logger = get_translation_logger()

    startup_message = (
        f"TranslationService initialized with default locale: {self.default_locale}, "
        f"LRU cache size: {cache_capacity}"
    )
    logger.info(startup_message)
|
||||
|
||||
def _load_all_locales(self):
    """Read every locale the loader reports and store its data in the cache."""
    locales = self.loader.get_available_locales()
    logger.info(f"Loading translations for locales: {locales}")

    for code in locales:
        self.cache.set_locale_data(code, self.loader.load_locale(code))

    locale_count = len(locales)
    logger.info(f"Loaded {locale_count} locales into cache")
|
||||
|
||||
@monitor_performance("translate")
def translate(
    self,
    key: str,
    locale: Optional[str] = None,
    **params
) -> str:
    """
    Translate a key to the target locale.

    Supports:
    - Dot-separated keys (e.g., "common.success.created")
    - Parameterized messages (e.g., "Hello {name}")
    - Fallback mechanism (requested locale, then ``self.fallback_locale``,
      then the key itself)

    Args:
        key: Translation key (format: "namespace.key.subkey"). A key
            without any dot is treated as invalid and returned unchanged.
        locale: Target locale (``self.default_locale`` is used when None)
        **params: Parameters for parameterized messages, applied with
            ``str.format``

    Returns:
        Translated string, or the key itself if translation not found.
        If formatting with ``params`` fails, the error is logged and the
        unformatted translation is returned.

    Examples:
        translate("common.success.created", "zh")
        # => "创建成功"

        translate("common.validation.required", "zh", field="名称")
        # => "名称不能为空"
    """
    if locale is None:
        locale = self.default_locale

    # Parse key: the part before the first dot is the namespace, the rest
    # is the path inside that namespace.
    parts = key.split(".", 1)
    if len(parts) < 2:
        if self.log_missing:
            logger.warning(f"Invalid translation key format: {key}")
        return key

    namespace = parts[0]
    key_path = parts[1].split(".")

    # Track request (counted even if the lookup below finds nothing)
    track_translation_request(locale, namespace)

    # Get translation from cache
    translation = self.cache.get_translation(locale, namespace, key_path)

    # Retry with the configured fallback locale if not found; skipped when
    # the requested locale already is the fallback locale.
    if translation is None and locale != self.fallback_locale:
        translation = self.cache.get_translation(
            self.fallback_locale, namespace, key_path
        )

    # If still not found, return the key itself
    if translation is None:
        # NOTE(review): indentation was lost in this view; the tracking and
        # translation-logger calls are assumed to sit inside the
        # ``log_missing`` branch - confirm against the original file.
        if self.log_missing:
            logger.warning(
                f"Missing translation: {key} (locale: {locale})"
            )
            track_missing_translation(key, locale)

            # Log to translation logger with context
            self.translation_logger.log_missing_translation(
                key=key,
                locale=locale,
                context={"namespace": namespace}
            )
        return key

    # Apply parameters if provided
    if params:
        try:
            translation = translation.format(**params)
        except KeyError as e:
            # The template references a placeholder the caller did not pass.
            error_msg = f"Missing parameter in translation '{key}': {e}"
            logger.error(error_msg)
            self.translation_logger.log_translation_error(
                error_type="parameter_missing",
                message=error_msg,
                key=key,
                locale=locale,
                context={"params": list(params.keys())}
            )
        except Exception as e:
            # Any other formatting failure; the unformatted translation is
            # still returned below.
            error_msg = f"Error formatting translation '{key}': {e}"
            logger.error(error_msg)
            self.translation_logger.log_translation_error(
                error_type="format_error",
                message=error_msg,
                key=key,
                locale=locale
            )

    return translation
|
||||
|
||||
def _get_translation(
|
||||
self,
|
||||
locale: str,
|
||||
namespace: str,
|
||||
key_path: list
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get translation from cache (deprecated, use cache.get_translation).
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key_path: List of nested keys
|
||||
|
||||
Returns:
|
||||
Translation string or None if not found
|
||||
"""
|
||||
return self.cache.get_translation(locale, namespace, key_path)
|
||||
|
||||
@monitor_performance("translate_enum")
def translate_enum(
    self,
    enum_type: str,
    value: str,
    locale: Optional[str] = None
) -> str:
    """
    Translate an enum value.

    The lookup key is built as "enums.<enum_type>.<value>" and passed to
    ``translate``.

    Args:
        enum_type: Enum type name (e.g., "workspace_role")
        value: Enum value (e.g., "manager"). An ``enum.Enum`` member is
            also accepted; its ``.value`` is used to build the key.
        locale: Target locale (``translate`` applies its default when None)

    Returns:
        Translated enum display name, or whatever ``translate`` returns
        for the built key when no translation exists

    Examples:
        translate_enum("workspace_role", "manager", "zh")
        # => "管理员"

        translate_enum("invite_status", "pending", "en")
        # => "Pending"
    """
    from enum import Enum  # local import so the block does not depend on file-level imports

    # Interpolating an Enum member into an f-string can yield
    # "ClassName.MEMBER" (always for plain Enum, and for str-mixin enums on
    # Python 3.11+), which would never match a translation key. Use the
    # member's underlying value instead; plain strings pass through untouched.
    if isinstance(value, Enum):
        value = value.value

    key = f"enums.{enum_type}.{value}"
    return self.translate(key, locale)
|
||||
|
||||
def has_translation(self, key: str, locale: str) -> bool:
    """
    Report whether the cache holds a translation for a key in a locale.

    Only the given locale is queried here; this method does not retry
    with the fallback locale.

    Args:
        key: Translation key in "namespace.key.subkey" form
        locale: Locale code

    Returns:
        True if the cache lookup returns something other than None.
        False if it returns None or if the key contains no "." separator.
    """
    namespace, separator, remainder = key.partition(".")
    if not separator:
        # Without a "." there is no namespace/path split to look up.
        return False

    found = self.cache.get_translation(locale, namespace, remainder.split("."))
    return found is not None
|
||||
|
||||
def reload(self, locale: Optional[str] = None):
    """
    Reload translation data through the loader.

    Args:
        locale: Specific locale to reload. When None (or any other falsy
            value such as ""), all locales are reloaded.
    """
    target = locale or 'all'
    logger.info(f"Reloading translations for locale: {target}")

    if not locale:
        # Full reload, then clear all LRU cache
        self._load_all_locales()
        self.cache.clear_lru()
    else:
        # Reload just this locale, then clear LRU cache for this locale
        refreshed = self.loader.load_locale(locale)
        self.cache.set_locale_data(locale, refreshed)
        self.cache.clear_locale(locale)

    logger.info("Translation reload completed")
|
||||
|
||||
def get_available_locales(self) -> list:
    """
    Return the locales currently loaded in the cache.

    Returns:
        The result of ``self.cache.get_loaded_locales()`` (list of locale codes)
    """
    loaded_locales = self.cache.get_loaded_locales()
    return loaded_locales
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
    """
    Return the statistics reported by the translation cache.

    Returns:
        The dictionary produced by ``self.cache.get_stats()``
    """
    stats = self.cache.get_stats()
    return stats
|
||||
|
||||
def get_metrics_summary(self) -> Dict[str, Any]:
    """
    Return the summary reported by the metrics object.

    Returns:
        The dictionary produced by ``self.metrics.get_summary()``
    """
    summary = self.metrics.get_summary()
    return summary
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
    """
    Return the memory usage information reported by the cache.

    Returns:
        The dictionary produced by ``self.cache.get_memory_usage()``
    """
    usage = self.cache.get_memory_usage()
    return usage
|
||||
|
||||
def get_loaded_dirs(self) -> list:
    """
    Return the translation directories known to the loader.

    Returns:
        The loader's ``locales_dirs`` attribute itself (not a copy)
    """
    directories = self.loader.locales_dirs
    return directories
|
||||
|
||||
|
||||
# Global singleton instance. Starts as None and is created on the first
# call to get_translation_service(), which then keeps returning it.
_translation_service: Optional[TranslationService] = None
|
||||
|
||||
|
||||
def get_translation_service() -> TranslationService:
    """
    Return the shared TranslationService, creating it on first use.

    The instance is stored in the module-level ``_translation_service``
    variable, so later calls return the same object.

    Returns:
        TranslationService singleton
    """
    global _translation_service
    if _translation_service is not None:
        return _translation_service

    # First call: build the service with its default arguments and keep it.
    _translation_service = TranslationService()
    return _translation_service
|
||||
|
||||
|
||||
# Convenience functions for easy access
|
||||
def t(key: str, locale: Optional[str] = None, **params) -> str:
    """
    Shorthand for ``get_translation_service().translate(...)``.

    Args:
        key: Translation key
        locale: Target locale (optional; the service applies its default
            when omitted)
        **params: Parameters for parameterized messages

    Returns:
        The string returned by the service's ``translate`` method

    Examples:
        t("common.success.created")
        t("common.validation.required", field="名称")
        t("workspace.member_count", count=5)
    """
    return get_translation_service().translate(key, locale, **params)
|
||||
|
||||
|
||||
def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str:
    """
    Shorthand for ``get_translation_service().translate_enum(...)``.

    Args:
        enum_type: Enum type name
        value: Enum value
        locale: Target locale

    Returns:
        The string returned by the service's ``translate_enum`` method

    Examples:
        t_enum("workspace_role", "manager")
        t_enum("invite_status", "pending", "en")
    """
    return get_translation_service().translate_enum(enum_type, value, locale)
|
||||
26
api/app/locales/en/README.md
Normal file
26
api/app/locales/en/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# English Translation Files
|
||||
|
||||
This directory contains English translation files.
|
||||
|
||||
## File Structure
|
||||
|
||||
- `common.json` - Common translations (success messages, actions, validation)
|
||||
- `auth.json` - Authentication module translations
|
||||
- `workspace.json` - Workspace module translations
|
||||
- `tenant.json` - Tenant module translations
|
||||
- `errors.json` - Error message translations
|
||||
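- `enums.json` - Enum value translations
- `i18n.json` - i18n management messages (languages, namespaces, reload, metrics, logs)
- `users.json` - User module translations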
|
||||
|
||||
## Translation File Format
|
||||
|
||||
All translation files use JSON format and support nested structures.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"success": {
|
||||
"created": "Created successfully",
|
||||
"updated": "Updated successfully"
|
||||
}
|
||||
}
|
||||
```
|
||||
55
api/app/locales/en/auth.json
Normal file
55
api/app/locales/en/auth.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"login": {
|
||||
"success": "Login successful",
|
||||
"failed": "Login failed",
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"account_locked": "Account has been locked",
|
||||
"account_disabled": "Account has been disabled"
|
||||
},
|
||||
"logout": {
|
||||
"success": "Logout successful",
|
||||
"failed": "Logout failed"
|
||||
},
|
||||
"token": {
|
||||
"refresh_success": "Token refreshed successfully",
|
||||
"invalid": "Invalid token",
|
||||
"expired": "Token has expired",
|
||||
"blacklisted": "Token has been invalidated",
|
||||
"invalid_refresh_token": "Invalid refresh token",
|
||||
"refresh_token_blacklisted": "Refresh token has been invalidated"
|
||||
},
|
||||
"registration": {
|
||||
"success": "Registration successful",
|
||||
"failed": "Registration failed",
|
||||
"email_exists": "Email already in use",
|
||||
"username_exists": "Username already taken"
|
||||
},
|
||||
"password": {
|
||||
"reset_success": "Password reset successful",
|
||||
"reset_failed": "Password reset failed",
|
||||
"change_success": "Password changed successfully",
|
||||
"change_failed": "Password change failed",
|
||||
"incorrect": "Incorrect password",
|
||||
"too_weak": "Password is too weak",
|
||||
"mismatch": "Passwords do not match"
|
||||
},
|
||||
"invite": {
|
||||
"invalid": "Invalid or expired invite code",
|
||||
"email_mismatch": "Invite email does not match login email",
|
||||
"accept_success": "Invite accepted successfully",
|
||||
"accept_failed": "Failed to accept invite",
|
||||
"password_verification_failed": "Failed to accept invite, password verification error",
|
||||
"bind_workspace_success": "Workspace bound successfully",
|
||||
"bind_workspace_failed": "Failed to bind workspace"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"created_with_invite": "User created successfully and joined workspace"
|
||||
},
|
||||
"session": {
|
||||
"expired": "Session expired, please log in again",
|
||||
"invalid": "Invalid session",
|
||||
"single_session_enabled": "Single-session login enabled, sessions on other devices will be logged out"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/en/common.json
Normal file
132
api/app/locales/en/common.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"success": {
|
||||
"created": "Created successfully",
|
||||
"updated": "Updated successfully",
|
||||
"deleted": "Deleted successfully",
|
||||
"retrieved": "Retrieved successfully",
|
||||
"saved": "Saved successfully",
|
||||
"uploaded": "Uploaded successfully",
|
||||
"downloaded": "Downloaded successfully",
|
||||
"sent": "Sent successfully",
|
||||
"completed": "Completed",
|
||||
"confirmed": "Confirmed",
|
||||
"cancelled": "Cancelled",
|
||||
"archived": "Archived",
|
||||
"restored": "Restored"
|
||||
},
|
||||
"actions": {
|
||||
"create": "Create",
|
||||
"update": "Update",
|
||||
"delete": "Delete",
|
||||
"view": "View",
|
||||
"edit": "Edit",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"confirm": "Confirm",
|
||||
"submit": "Submit",
|
||||
"upload": "Upload",
|
||||
"download": "Download",
|
||||
"send": "Send",
|
||||
"search": "Search",
|
||||
"filter": "Filter",
|
||||
"sort": "Sort",
|
||||
"export": "Export",
|
||||
"import": "Import",
|
||||
"refresh": "Refresh",
|
||||
"reset": "Reset",
|
||||
"back": "Back",
|
||||
"next": "Next",
|
||||
"previous": "Previous",
|
||||
"finish": "Finish",
|
||||
"close": "Close",
|
||||
"open": "Open",
|
||||
"archive": "Archive",
|
||||
"restore": "Restore",
|
||||
"duplicate": "Duplicate",
|
||||
"share": "Share",
|
||||
"invite": "Invite",
|
||||
"remove": "Remove",
|
||||
"add": "Add",
|
||||
"select": "Select",
|
||||
"clear": "Clear"
|
||||
},
|
||||
"validation": {
|
||||
"required": "{field} is required",
|
||||
"invalid_format": "{field} format is invalid",
|
||||
"too_long": "{field} cannot exceed {max} characters",
|
||||
"too_short": "{field} must be at least {min} characters",
|
||||
"invalid_email": "Invalid email format",
|
||||
"invalid_url": "Invalid URL format",
|
||||
"invalid_phone": "Invalid phone number format",
|
||||
"invalid_date": "Invalid date format",
|
||||
"invalid_number": "Must be a valid number",
|
||||
"out_of_range": "{field} must be between {min} and {max}",
|
||||
"already_exists": "{field} already exists",
|
||||
"not_found": "{field} not found",
|
||||
"invalid_value": "Invalid value for {field}",
|
||||
"password_mismatch": "Passwords do not match",
|
||||
"weak_password": "Password is too weak, please use a stronger password",
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Permission denied",
|
||||
"expired": "{field} has expired",
|
||||
"invalid_token": "Invalid token",
|
||||
"file_too_large": "File size cannot exceed {max}",
|
||||
"invalid_file_type": "Unsupported file type",
|
||||
"duplicate": "Duplicate {field}"
|
||||
},
|
||||
"status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"pending": "Pending",
|
||||
"processing": "Processing",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled",
|
||||
"archived": "Archived",
|
||||
"deleted": "Deleted",
|
||||
"draft": "Draft",
|
||||
"published": "Published",
|
||||
"suspended": "Suspended",
|
||||
"expired": "Expired"
|
||||
},
|
||||
"messages": {
|
||||
"loading": "Loading...",
|
||||
"saving": "Saving...",
|
||||
"processing": "Processing...",
|
||||
"uploading": "Uploading...",
|
||||
"downloading": "Downloading...",
|
||||
"no_data": "No data available",
|
||||
"no_results": "No results found",
|
||||
"confirm_delete": "Are you sure you want to delete? This action cannot be undone.",
|
||||
"confirm_action": "Are you sure you want to perform this action?",
|
||||
"operation_success": "Operation successful",
|
||||
"operation_failed": "Operation failed",
|
||||
"please_wait": "Please wait...",
|
||||
"try_again": "Please try again",
|
||||
"contact_support": "If the problem persists, please contact support"
|
||||
},
|
||||
"pagination": {
|
||||
"page": "Page {page}",
|
||||
"of": "of {total}",
|
||||
"items": "{total} items",
|
||||
"per_page": "{count} per page",
|
||||
"showing": "Showing {from} to {to} of {total}",
|
||||
"first": "First",
|
||||
"last": "Last",
|
||||
"next": "Next",
|
||||
"previous": "Previous"
|
||||
},
|
||||
"time": {
|
||||
"just_now": "Just now",
|
||||
"minutes_ago": "{count} minutes ago",
|
||||
"hours_ago": "{count} hours ago",
|
||||
"days_ago": "{count} days ago",
|
||||
"weeks_ago": "{count} weeks ago",
|
||||
"months_ago": "{count} months ago",
|
||||
"years_ago": "{count} years ago",
|
||||
"today": "Today",
|
||||
"yesterday": "Yesterday",
|
||||
"tomorrow": "Tomorrow"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/en/enums.json
Normal file
132
api/app/locales/en/enums.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"workspace_role": {
|
||||
"owner": "Owner",
|
||||
"manager": "Manager",
|
||||
"member": "Member",
|
||||
"guest": "Guest"
|
||||
},
|
||||
"workspace_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"archived": "Archived",
|
||||
"suspended": "Suspended",
|
||||
"deleted": "Deleted"
|
||||
},
|
||||
"invite_status": {
|
||||
"pending": "Pending",
|
||||
"accepted": "Accepted",
|
||||
"rejected": "Rejected",
|
||||
"revoked": "Revoked",
|
||||
"expired": "Expired"
|
||||
},
|
||||
"user_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"suspended": "Suspended",
|
||||
"deleted": "Deleted",
|
||||
"pending": "Pending"
|
||||
},
|
||||
"tenant_status": {
|
||||
"active": "Active",
|
||||
"inactive": "Inactive",
|
||||
"suspended": "Suspended",
|
||||
"expired": "Expired",
|
||||
"trial": "Trial"
|
||||
},
|
||||
"file_status": {
|
||||
"uploading": "Uploading",
|
||||
"processing": "Processing",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"deleted": "Deleted"
|
||||
},
|
||||
"task_status": {
|
||||
"pending": "Pending",
|
||||
"running": "Running",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled",
|
||||
"paused": "Paused"
|
||||
},
|
||||
"priority": {
|
||||
"low": "Low",
|
||||
"medium": "Medium",
|
||||
"high": "High",
|
||||
"urgent": "Urgent"
|
||||
},
|
||||
"visibility": {
|
||||
"public": "Public",
|
||||
"private": "Private",
|
||||
"internal": "Internal",
|
||||
"shared": "Shared"
|
||||
},
|
||||
"permission": {
|
||||
"read": "Read",
|
||||
"write": "Write",
|
||||
"delete": "Delete",
|
||||
"admin": "Admin",
|
||||
"owner": "Owner"
|
||||
},
|
||||
"notification_type": {
|
||||
"info": "Info",
|
||||
"warning": "Warning",
|
||||
"error": "Error",
|
||||
"success": "Success"
|
||||
},
|
||||
"language": {
|
||||
"zh": "Chinese (Simplified)",
|
||||
"en": "English",
|
||||
"ja": "Japanese",
|
||||
"ko": "Korean",
|
||||
"fr": "French",
|
||||
"de": "German",
|
||||
"es": "Spanish"
|
||||
},
|
||||
"timezone": {
|
||||
"utc": "UTC",
|
||||
"asia_shanghai": "Asia/Shanghai",
|
||||
"asia_tokyo": "Asia/Tokyo",
|
||||
"america_new_york": "America/New_York",
|
||||
"europe_london": "Europe/London"
|
||||
},
|
||||
"date_format": {
|
||||
"short": "Short",
|
||||
"medium": "Medium",
|
||||
"long": "Long",
|
||||
"full": "Full"
|
||||
},
|
||||
"sort_order": {
|
||||
"asc": "Ascending",
|
||||
"desc": "Descending"
|
||||
},
|
||||
"filter_operator": {
|
||||
"equals": "Equals",
|
||||
"not_equals": "Not Equals",
|
||||
"contains": "Contains",
|
||||
"not_contains": "Not Contains",
|
||||
"starts_with": "Starts With",
|
||||
"ends_with": "Ends With",
|
||||
"greater_than": "Greater Than",
|
||||
"less_than": "Less Than",
|
||||
"greater_or_equal": "Greater or Equal",
|
||||
"less_or_equal": "Less or Equal",
|
||||
"in": "In",
|
||||
"not_in": "Not In",
|
||||
"is_null": "Is Null",
|
||||
"is_not_null": "Is Not Null"
|
||||
},
|
||||
"log_level": {
|
||||
"debug": "Debug",
|
||||
"info": "Info",
|
||||
"warning": "Warning",
|
||||
"error": "Error",
|
||||
"critical": "Critical"
|
||||
},
|
||||
"api_method": {
|
||||
"get": "GET",
|
||||
"post": "POST",
|
||||
"put": "PUT",
|
||||
"patch": "PATCH",
|
||||
"delete": "DELETE"
|
||||
}
|
||||
}
|
||||
138
api/app/locales/en/errors.json
Normal file
138
api/app/locales/en/errors.json
Normal file
@@ -0,0 +1,138 @@
|
||||
{
|
||||
"common": {
|
||||
"internal_error": "Internal server error",
|
||||
"network_error": "Network connection error",
|
||||
"timeout": "Request timeout",
|
||||
"service_unavailable": "Service temporarily unavailable",
|
||||
"bad_request": "Bad request parameters",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Access forbidden",
|
||||
"not_found": "Resource not found",
|
||||
"method_not_allowed": "Method not allowed",
|
||||
"conflict": "Resource conflict",
|
||||
"too_many_requests": "Too many requests, please try again later",
|
||||
"validation_failed": "Validation failed",
|
||||
"database_error": "Database operation failed",
|
||||
"file_operation_error": "File operation failed"
|
||||
},
|
||||
"auth": {
|
||||
"invalid_credentials": "Invalid username or password",
|
||||
"token_expired": "Session expired, please log in again",
|
||||
"token_invalid": "Invalid authentication token",
|
||||
"token_missing": "Authentication token missing",
|
||||
"unauthorized": "Unauthorized access",
|
||||
"forbidden": "Permission denied",
|
||||
"account_locked": "Account has been locked",
|
||||
"account_disabled": "Account has been disabled",
|
||||
"account_not_verified": "Account not verified",
|
||||
"password_incorrect": "Incorrect password",
|
||||
"password_too_weak": "Password is too weak",
|
||||
"password_expired": "Password expired, please change it",
|
||||
"email_not_verified": "Email not verified",
|
||||
"phone_not_verified": "Phone number not verified",
|
||||
"verification_code_invalid": "Invalid verification code",
|
||||
"verification_code_expired": "Verification code expired",
|
||||
"login_failed": "Login failed",
|
||||
"logout_failed": "Logout failed",
|
||||
"session_expired": "Session expired",
|
||||
"already_logged_in": "Already logged in",
|
||||
"not_logged_in": "Not logged in"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"email_already_exists": "Email already in use",
|
||||
"phone_already_exists": "Phone number already in use",
|
||||
"username_already_exists": "Username already taken",
|
||||
"invalid_email": "Invalid email format",
|
||||
"invalid_phone": "Invalid phone number format",
|
||||
"invalid_username": "Invalid username format",
|
||||
"create_failed": "Failed to create user",
|
||||
"update_failed": "Failed to update user",
|
||||
"delete_failed": "Failed to delete user",
|
||||
"cannot_delete_self": "Cannot delete yourself",
|
||||
"cannot_update_self_role": "Cannot update your own role",
|
||||
"profile_update_failed": "Failed to update profile",
|
||||
"avatar_upload_failed": "Failed to upload avatar",
|
||||
"password_change_failed": "Failed to change password",
|
||||
"old_password_incorrect": "Old password is incorrect"
|
||||
},
|
||||
"workspace": {
|
||||
"not_found": "Workspace not found",
|
||||
"already_exists": "Workspace already exists",
|
||||
"name_required": "Workspace name is required",
|
||||
"name_too_long": "Workspace name is too long",
|
||||
"create_failed": "Failed to create workspace",
|
||||
"update_failed": "Failed to update workspace",
|
||||
"delete_failed": "Failed to delete workspace",
|
||||
"permission_denied": "Permission denied to access this workspace",
|
||||
"not_member": "Not a workspace member",
|
||||
"already_member": "Already a workspace member",
|
||||
"member_limit_reached": "Member limit reached",
|
||||
"cannot_leave_last_manager": "Cannot leave, you are the last manager",
|
||||
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||
"cannot_remove_self": "Cannot remove yourself",
|
||||
"invite_not_found": "Invite not found",
|
||||
"invite_expired": "Invite has expired",
|
||||
"invite_already_accepted": "Invite already accepted",
|
||||
"invite_already_revoked": "Invite already revoked",
|
||||
"invite_send_failed": "Failed to send invite",
|
||||
"archived": "Workspace is archived",
|
||||
"suspended": "Workspace is suspended"
|
||||
},
|
||||
"tenant": {
|
||||
"not_found": "Tenant not found",
|
||||
"already_exists": "Tenant already exists",
|
||||
"create_failed": "Failed to create tenant",
|
||||
"update_failed": "Failed to update tenant",
|
||||
"delete_failed": "Failed to delete tenant",
|
||||
"suspended": "Tenant is suspended",
|
||||
"expired": "Tenant has expired",
|
||||
"license_invalid": "Invalid license",
|
||||
"license_expired": "License has expired",
|
||||
"quota_exceeded": "Quota exceeded"
|
||||
},
|
||||
"file": {
|
||||
"not_found": "File not found",
|
||||
"upload_failed": "File upload failed",
|
||||
"download_failed": "File download failed",
|
||||
"delete_failed": "File deletion failed",
|
||||
"too_large": "File size exceeds limit",
|
||||
"invalid_type": "Unsupported file type",
|
||||
"invalid_format": "Invalid file format",
|
||||
"corrupted": "File is corrupted",
|
||||
"storage_full": "Storage is full",
|
||||
"access_denied": "Access denied to this file"
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API rate limit exceeded",
|
||||
"quota_exceeded": "API quota exceeded",
|
||||
"invalid_api_key": "Invalid API key",
|
||||
"api_key_expired": "API key has expired",
|
||||
"api_key_revoked": "API key has been revoked",
|
||||
"endpoint_not_found": "API endpoint not found",
|
||||
"method_not_allowed": "Method not allowed",
|
||||
"invalid_request": "Invalid request",
|
||||
"missing_parameter": "Missing required parameter: {param}",
|
||||
"invalid_parameter": "Invalid parameter: {param}"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "Database connection failed",
|
||||
"query_failed": "Database query failed",
|
||||
"transaction_failed": "Database transaction failed",
|
||||
"constraint_violation": "Data constraint violation",
|
||||
"duplicate_key": "Duplicate data",
|
||||
"foreign_key_violation": "Foreign key constraint violation",
|
||||
"deadlock": "Database deadlock"
|
||||
},
|
||||
"validation": {
|
||||
"invalid_input": "Invalid input data",
|
||||
"missing_field": "Missing required field: {field}",
|
||||
"invalid_field": "Invalid field: {field}",
|
||||
"field_too_long": "Field too long: {field}",
|
||||
"field_too_short": "Field too short: {field}",
|
||||
"invalid_format": "Invalid format: {field}",
|
||||
"invalid_value": "Invalid value: {field}",
|
||||
"out_of_range": "Value out of range: {field}"
|
||||
}
|
||||
}
|
||||
27
api/app/locales/en/i18n.json
Normal file
27
api/app/locales/en/i18n.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"language": {
|
||||
"not_found": "Language {locale} not found",
|
||||
"already_exists": "Language {locale} already exists",
|
||||
"add_instructions": "Language {locale} validated successfully. Please create translation files in {dir} directory to complete the addition.",
|
||||
"update_instructions": "Language {locale} update validated successfully. Please update I18N_SUPPORTED_LANGUAGES environment variable to apply configuration changes."
|
||||
},
|
||||
"namespace": {
|
||||
"not_found": "Namespace {namespace} not found in language {locale}"
|
||||
},
|
||||
"translation": {
|
||||
"invalid_key_format": "Invalid translation key format: {key}. Should use format: namespace.key.subkey",
|
||||
"update_instructions": "Translation {locale}/{key} update validated successfully. Please modify the corresponding JSON translation file to apply changes."
|
||||
},
|
||||
"reload": {
|
||||
"disabled": "Translation hot reload is disabled. Please enable I18N_ENABLE_HOT_RELOAD in configuration.",
|
||||
"success": "Translations reloaded successfully",
|
||||
"failed": "Translation reload failed: {error}"
|
||||
},
|
||||
"metrics": {
|
||||
"reset_success": "Performance metrics reset successfully"
|
||||
},
|
||||
"logs": {
|
||||
"export_success": "Missing translations exported to: {file}",
|
||||
"clear_success": "Missing translation logs cleared successfully"
|
||||
}
|
||||
}
|
||||
63
api/app/locales/en/tenant.json
Normal file
63
api/app/locales/en/tenant.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "Tenant information retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant information",
|
||||
"update_success": "Tenant information updated successfully",
|
||||
"update_failed": "Failed to update tenant information"
|
||||
},
|
||||
"create": {
|
||||
"success": "Tenant created successfully",
|
||||
"failed": "Failed to create tenant"
|
||||
},
|
||||
"delete": {
|
||||
"success": "Tenant deleted successfully",
|
||||
"failed": "Failed to delete tenant"
|
||||
},
|
||||
"status": {
|
||||
"activate_success": "Tenant activated successfully",
|
||||
"activate_failed": "Failed to activate tenant",
|
||||
"deactivate_success": "Tenant deactivated successfully",
|
||||
"deactivate_failed": "Failed to deactivate tenant"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "Tenant language configuration retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant language configuration",
|
||||
"update_success": "Tenant language configuration updated successfully",
|
||||
"update_failed": "Failed to update tenant language configuration",
|
||||
"invalid_language": "Unsupported language code",
|
||||
"default_not_in_supported": "Default language must be in the supported languages list"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "Tenant list retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant list"
|
||||
},
|
||||
"users": {
|
||||
"list_success": "Tenant user list retrieved successfully",
|
||||
"list_failed": "Failed to retrieve tenant user list",
|
||||
"assign_success": "User assigned to tenant successfully",
|
||||
"assign_failed": "Failed to assign user to tenant",
|
||||
"remove_success": "User removed from tenant successfully",
|
||||
"remove_failed": "Failed to remove user from tenant"
|
||||
},
|
||||
"statistics": {
|
||||
"get_success": "Tenant statistics retrieved successfully",
|
||||
"get_failed": "Failed to retrieve tenant statistics"
|
||||
},
|
||||
"validation": {
|
||||
"name_required": "Tenant name is required",
|
||||
"name_invalid": "Invalid tenant name format",
|
||||
"name_too_long": "Tenant name cannot exceed {max} characters",
|
||||
"description_too_long": "Tenant description cannot exceed {max} characters",
|
||||
"language_code_invalid": "Invalid language code format",
|
||||
"supported_languages_empty": "Supported languages list cannot be empty"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "Tenant not found",
|
||||
"already_exists": "Tenant name already exists",
|
||||
"permission_denied": "Permission denied to access this tenant",
|
||||
"has_users": "Cannot delete tenant, associated users exist",
|
||||
"has_workspaces": "Cannot delete tenant, associated workspaces exist",
|
||||
"already_active": "Tenant is already active",
|
||||
"already_inactive": "Tenant is already inactive"
|
||||
}
|
||||
}
|
||||
72
api/app/locales/en/users.json
Normal file
72
api/app/locales/en/users.json
Normal file
@@ -0,0 +1,72 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "User information retrieved successfully",
|
||||
"get_failed": "Failed to retrieve user information",
|
||||
"update_success": "User information updated successfully",
|
||||
"update_failed": "Failed to update user information"
|
||||
},
|
||||
"create": {
|
||||
"success": "User created successfully",
|
||||
"failed": "Failed to create user",
|
||||
"superuser_success": "Superuser created successfully",
|
||||
"superuser_failed": "Failed to create superuser"
|
||||
},
|
||||
"delete": {
|
||||
"success": "User deleted successfully",
|
||||
"failed": "Failed to delete user",
|
||||
"deactivate_success": "User deactivated successfully",
|
||||
"deactivate_failed": "Failed to deactivate user"
|
||||
},
|
||||
"activate": {
|
||||
"success": "User activated successfully",
|
||||
"failed": "Failed to activate user"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "Language preference retrieved successfully",
|
||||
"get_failed": "Failed to retrieve language preference",
|
||||
"update_success": "Language preference updated successfully",
|
||||
"update_failed": "Failed to update language preference",
|
||||
"invalid_language": "Unsupported language code",
|
||||
"current": "Current language preference"
|
||||
},
|
||||
"email": {
|
||||
"change_success": "Email changed successfully",
|
||||
"change_failed": "Failed to change email",
|
||||
"code_sent": "Verification code has been sent to your email",
|
||||
"code_send_failed": "Failed to send verification code",
|
||||
"code_invalid": "Invalid or expired verification code",
|
||||
"already_exists": "Email already in use"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "User list retrieved successfully",
|
||||
"get_failed": "Failed to retrieve user list",
|
||||
"superusers_success": "Tenant superuser list retrieved successfully",
|
||||
"superusers_failed": "Failed to retrieve tenant superuser list"
|
||||
},
|
||||
"validation": {
|
||||
"username_required": "Username is required",
|
||||
"username_invalid": "Invalid username format",
|
||||
"username_too_long": "Username cannot exceed {max} characters",
|
||||
"email_required": "Email is required",
|
||||
"email_invalid": "Invalid email format",
|
||||
"password_required": "Password is required",
|
||||
"password_too_short": "Password must be at least {min} characters",
|
||||
"password_too_long": "Password cannot exceed {max} characters",
|
||||
"old_password_required": "Old password is required",
|
||||
"new_password_required": "New password is required",
|
||||
"verification_code_required": "Verification code is required",
|
||||
"verification_code_invalid": "Invalid verification code format"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "User not found",
|
||||
"already_exists": "User already exists",
|
||||
"permission_denied": "Permission denied to access this user",
|
||||
"cannot_delete_self": "Cannot delete yourself",
|
||||
"cannot_deactivate_self": "Cannot deactivate yourself",
|
||||
"already_deactivated": "User is already deactivated",
|
||||
"already_activated": "User is already activated",
|
||||
"password_verification_failed": "Password verification failed",
|
||||
"old_password_incorrect": "Old password is incorrect",
|
||||
"same_as_old_password": "New password cannot be the same as old password"
|
||||
}
|
||||
}
|
||||
44
api/app/locales/en/workspace.json
Normal file
44
api/app/locales/en/workspace.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"list_retrieved": "Workspace list retrieved successfully",
|
||||
"created": "Workspace created successfully",
|
||||
"updated": "Workspace updated successfully",
|
||||
"deleted": "Workspace deleted successfully",
|
||||
"switched": "Workspace switched successfully",
|
||||
"not_found": "Workspace not found or access denied",
|
||||
"already_exists": "Workspace already exists",
|
||||
"permission_denied": "No permission to access this workspace",
|
||||
"name_required": "Workspace name is required",
|
||||
"invalid_name": "Invalid workspace name format",
|
||||
"members": {
|
||||
"list_retrieved": "Workspace members list retrieved successfully",
|
||||
"role_updated": "Member role updated successfully",
|
||||
"deleted": "Member deleted successfully",
|
||||
"not_found": "Member not found",
|
||||
"cannot_remove_self": "Cannot remove yourself",
|
||||
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||
"already_member": "User is already a workspace member"
|
||||
},
|
||||
"invites": {
|
||||
"created": "Invite created successfully",
|
||||
"list_retrieved": "Invite list retrieved successfully",
|
||||
"validated": "Invite validated successfully",
|
||||
"revoked": "Invite revoked successfully",
|
||||
"accepted": "Invite accepted",
|
||||
"not_found": "Invite not found",
|
||||
"expired": "Invite has expired",
|
||||
"already_used": "Invite has already been used",
|
||||
"invalid_token": "Invalid invite token",
|
||||
"email_required": "Email address is required",
|
||||
"invalid_email": "Invalid email address format"
|
||||
},
|
||||
"storage": {
|
||||
"type_retrieved": "Storage type retrieved successfully",
|
||||
"type_updated": "Storage type updated successfully",
|
||||
"invalid_type": "Invalid storage type"
|
||||
},
|
||||
"models": {
|
||||
"config_retrieved": "Model configuration retrieved successfully",
|
||||
"config_updated": "Model configuration updated successfully",
|
||||
"invalid_config": "Invalid model configuration"
|
||||
}
|
||||
}
|
||||
26
api/app/locales/zh/README.md
Normal file
26
api/app/locales/zh/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# 中文翻译文件
|
||||
|
||||
此目录包含中文(简体)的翻译文件。
|
||||
|
||||
## 文件结构
|
||||
|
||||
- `common.json` - 通用翻译(成功消息、操作、验证)
|
||||
- `auth.json` - 认证模块翻译
|
||||
- `workspace.json` - 工作空间模块翻译
|
||||
- `tenant.json` - 租户模块翻译
|
||||
- `errors.json` - 错误消息翻译
|
||||
- `enums.json` - 枚举值翻译
|
||||
|
||||
## 翻译文件格式
|
||||
|
||||
所有翻译文件使用 JSON 格式,支持嵌套结构。
|
||||
|
||||
示例:
|
||||
```json
|
||||
{
|
||||
"success": {
|
||||
"created": "创建成功",
|
||||
"updated": "更新成功"
|
||||
}
|
||||
}
|
||||
```
|
||||
55
api/app/locales/zh/auth.json
Normal file
55
api/app/locales/zh/auth.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"login": {
|
||||
"success": "登录成功",
|
||||
"failed": "登录失败",
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"account_locked": "账户已被锁定",
|
||||
"account_disabled": "账户已被禁用"
|
||||
},
|
||||
"logout": {
|
||||
"success": "登出成功",
|
||||
"failed": "登出失败"
|
||||
},
|
||||
"token": {
|
||||
"refresh_success": "token刷新成功",
|
||||
"invalid": "无效的token",
|
||||
"expired": "token已过期",
|
||||
"blacklisted": "token已失效",
|
||||
"invalid_refresh_token": "无效的refresh token",
|
||||
"refresh_token_blacklisted": "Refresh token已失效"
|
||||
},
|
||||
"registration": {
|
||||
"success": "注册成功",
|
||||
"failed": "注册失败",
|
||||
"email_exists": "邮箱已被使用",
|
||||
"username_exists": "用户名已被使用"
|
||||
},
|
||||
"password": {
|
||||
"reset_success": "密码重置成功",
|
||||
"reset_failed": "密码重置失败",
|
||||
"change_success": "密码修改成功",
|
||||
"change_failed": "密码修改失败",
|
||||
"incorrect": "密码错误",
|
||||
"too_weak": "密码强度不够",
|
||||
"mismatch": "两次输入的密码不一致"
|
||||
},
|
||||
"invite": {
|
||||
"invalid": "邀请码无效或已过期",
|
||||
"email_mismatch": "邀请邮箱与登录邮箱不匹配",
|
||||
"accept_success": "接受邀请成功",
|
||||
"accept_failed": "接受邀请失败",
|
||||
"password_verification_failed": "接受邀请失败,密码验证错误",
|
||||
"bind_workspace_success": "绑定工作空间成功",
|
||||
"bind_workspace_failed": "绑定工作空间失败"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"created_with_invite": "用户创建成功并已加入工作空间"
|
||||
},
|
||||
"session": {
|
||||
"expired": "会话已过期,请重新登录",
|
||||
"invalid": "无效的会话",
|
||||
"single_session_enabled": "单点登录已启用,其他设备的登录将被注销"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/zh/common.json
Normal file
132
api/app/locales/zh/common.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"success": {
|
||||
"created": "创建成功",
|
||||
"updated": "更新成功",
|
||||
"deleted": "删除成功",
|
||||
"retrieved": "获取成功",
|
||||
"saved": "保存成功",
|
||||
"uploaded": "上传成功",
|
||||
"downloaded": "下载成功",
|
||||
"sent": "发送成功",
|
||||
"completed": "完成",
|
||||
"confirmed": "已确认",
|
||||
"cancelled": "已取消",
|
||||
"archived": "已归档",
|
||||
"restored": "已恢复"
|
||||
},
|
||||
"actions": {
|
||||
"create": "创建",
|
||||
"update": "更新",
|
||||
"delete": "删除",
|
||||
"view": "查看",
|
||||
"edit": "编辑",
|
||||
"save": "保存",
|
||||
"cancel": "取消",
|
||||
"confirm": "确认",
|
||||
"submit": "提交",
|
||||
"upload": "上传",
|
||||
"download": "下载",
|
||||
"send": "发送",
|
||||
"search": "搜索",
|
||||
"filter": "筛选",
|
||||
"sort": "排序",
|
||||
"export": "导出",
|
||||
"import": "导入",
|
||||
"refresh": "刷新",
|
||||
"reset": "重置",
|
||||
"back": "返回",
|
||||
"next": "下一步",
|
||||
"previous": "上一步",
|
||||
"finish": "完成",
|
||||
"close": "关闭",
|
||||
"open": "打开",
|
||||
"archive": "归档",
|
||||
"restore": "恢复",
|
||||
"duplicate": "复制",
|
||||
"share": "分享",
|
||||
"invite": "邀请",
|
||||
"remove": "移除",
|
||||
"add": "添加",
|
||||
"select": "选择",
|
||||
"clear": "清除"
|
||||
},
|
||||
"validation": {
|
||||
"required": "{field}不能为空",
|
||||
"invalid_format": "{field}格式不正确",
|
||||
"too_long": "{field}长度不能超过{max}个字符",
|
||||
"too_short": "{field}长度不能少于{min}个字符",
|
||||
"invalid_email": "邮箱格式不正确",
|
||||
"invalid_url": "URL格式不正确",
|
||||
"invalid_phone": "手机号格式不正确",
|
||||
"invalid_date": "日期格式不正确",
|
||||
"invalid_number": "必须是有效的数字",
|
||||
"out_of_range": "{field}必须在{min}和{max}之间",
|
||||
"already_exists": "{field}已存在",
|
||||
"not_found": "{field}不存在",
|
||||
"invalid_value": "{field}的值无效",
|
||||
"password_mismatch": "两次输入的密码不一致",
|
||||
"weak_password": "密码强度不够,请使用更复杂的密码",
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限执行此操作",
|
||||
"expired": "{field}已过期",
|
||||
"invalid_token": "无效的令牌",
|
||||
"file_too_large": "文件大小不能超过{max}",
|
||||
"invalid_file_type": "不支持的文件类型",
|
||||
"duplicate": "重复的{field}"
|
||||
},
|
||||
"status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"pending": "待处理",
|
||||
"processing": "处理中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"cancelled": "已取消",
|
||||
"archived": "已归档",
|
||||
"deleted": "已删除",
|
||||
"draft": "草稿",
|
||||
"published": "已发布",
|
||||
"suspended": "已暂停",
|
||||
"expired": "已过期"
|
||||
},
|
||||
"messages": {
|
||||
"loading": "加载中...",
|
||||
"saving": "保存中...",
|
||||
"processing": "处理中...",
|
||||
"uploading": "上传中...",
|
||||
"downloading": "下载中...",
|
||||
"no_data": "暂无数据",
|
||||
"no_results": "没有找到结果",
|
||||
"confirm_delete": "确定要删除吗?此操作不可恢复。",
|
||||
"confirm_action": "确定要执行此操作吗?",
|
||||
"operation_success": "操作成功",
|
||||
"operation_failed": "操作失败",
|
||||
"please_wait": "请稍候...",
|
||||
"try_again": "请重试",
|
||||
"contact_support": "如果问题持续,请联系技术支持"
|
||||
},
|
||||
"pagination": {
|
||||
"page": "第{page}页",
|
||||
"of": "共{total}页",
|
||||
"items": "共{total}条",
|
||||
"per_page": "每页{count}条",
|
||||
"showing": "显示第{from}到第{to}条,共{total}条",
|
||||
"first": "首页",
|
||||
"last": "末页",
|
||||
"next": "下一页",
|
||||
"previous": "上一页"
|
||||
},
|
||||
"time": {
|
||||
"just_now": "刚刚",
|
||||
"minutes_ago": "{count}分钟前",
|
||||
"hours_ago": "{count}小时前",
|
||||
"days_ago": "{count}天前",
|
||||
"weeks_ago": "{count}周前",
|
||||
"months_ago": "{count}个月前",
|
||||
"years_ago": "{count}年前",
|
||||
"today": "今天",
|
||||
"yesterday": "昨天",
|
||||
"tomorrow": "明天"
|
||||
}
|
||||
}
|
||||
132
api/app/locales/zh/enums.json
Normal file
132
api/app/locales/zh/enums.json
Normal file
@@ -0,0 +1,132 @@
|
||||
{
|
||||
"workspace_role": {
|
||||
"owner": "所有者",
|
||||
"manager": "管理员",
|
||||
"member": "成员",
|
||||
"guest": "访客"
|
||||
},
|
||||
"workspace_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"archived": "已归档",
|
||||
"suspended": "已暂停",
|
||||
"deleted": "已删除"
|
||||
},
|
||||
"invite_status": {
|
||||
"pending": "待处理",
|
||||
"accepted": "已接受",
|
||||
"rejected": "已拒绝",
|
||||
"revoked": "已撤销",
|
||||
"expired": "已过期"
|
||||
},
|
||||
"user_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"suspended": "已暂停",
|
||||
"deleted": "已删除",
|
||||
"pending": "待激活"
|
||||
},
|
||||
"tenant_status": {
|
||||
"active": "活跃",
|
||||
"inactive": "未激活",
|
||||
"suspended": "已暂停",
|
||||
"expired": "已过期",
|
||||
"trial": "试用中"
|
||||
},
|
||||
"file_status": {
|
||||
"uploading": "上传中",
|
||||
"processing": "处理中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"deleted": "已删除"
|
||||
},
|
||||
"task_status": {
|
||||
"pending": "待处理",
|
||||
"running": "运行中",
|
||||
"completed": "已完成",
|
||||
"failed": "失败",
|
||||
"cancelled": "已取消",
|
||||
"paused": "已暂停"
|
||||
},
|
||||
"priority": {
|
||||
"low": "低",
|
||||
"medium": "中",
|
||||
"high": "高",
|
||||
"urgent": "紧急"
|
||||
},
|
||||
"visibility": {
|
||||
"public": "公开",
|
||||
"private": "私有",
|
||||
"internal": "内部",
|
||||
"shared": "共享"
|
||||
},
|
||||
"permission": {
|
||||
"read": "读取",
|
||||
"write": "写入",
|
||||
"delete": "删除",
|
||||
"admin": "管理",
|
||||
"owner": "所有者"
|
||||
},
|
||||
"notification_type": {
|
||||
"info": "信息",
|
||||
"warning": "警告",
|
||||
"error": "错误",
|
||||
"success": "成功"
|
||||
},
|
||||
"language": {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
},
|
||||
"timezone": {
|
||||
"utc": "UTC",
|
||||
"asia_shanghai": "亚洲/上海",
|
||||
"asia_tokyo": "亚洲/东京",
|
||||
"america_new_york": "美洲/纽约",
|
||||
"europe_london": "欧洲/伦敦"
|
||||
},
|
||||
"date_format": {
|
||||
"short": "短日期",
|
||||
"medium": "中等日期",
|
||||
"long": "长日期",
|
||||
"full": "完整日期"
|
||||
},
|
||||
"sort_order": {
|
||||
"asc": "升序",
|
||||
"desc": "降序"
|
||||
},
|
||||
"filter_operator": {
|
||||
"equals": "等于",
|
||||
"not_equals": "不等于",
|
||||
"contains": "包含",
|
||||
"not_contains": "不包含",
|
||||
"starts_with": "开始于",
|
||||
"ends_with": "结束于",
|
||||
"greater_than": "大于",
|
||||
"less_than": "小于",
|
||||
"greater_or_equal": "大于等于",
|
||||
"less_or_equal": "小于等于",
|
||||
"in": "在列表中",
|
||||
"not_in": "不在列表中",
|
||||
"is_null": "为空",
|
||||
"is_not_null": "不为空"
|
||||
},
|
||||
"log_level": {
|
||||
"debug": "调试",
|
||||
"info": "信息",
|
||||
"warning": "警告",
|
||||
"error": "错误",
|
||||
"critical": "严重"
|
||||
},
|
||||
"api_method": {
|
||||
"get": "GET",
|
||||
"post": "POST",
|
||||
"put": "PUT",
|
||||
"patch": "PATCH",
|
||||
"delete": "DELETE"
|
||||
}
|
||||
}
|
||||
138
api/app/locales/zh/errors.json
Normal file
138
api/app/locales/zh/errors.json
Normal file
@@ -0,0 +1,138 @@
|
||||
{
|
||||
"common": {
|
||||
"internal_error": "服务器内部错误",
|
||||
"network_error": "网络连接错误",
|
||||
"timeout": "请求超时",
|
||||
"service_unavailable": "服务暂时不可用",
|
||||
"bad_request": "请求参数错误",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限访问",
|
||||
"not_found": "请求的资源不存在",
|
||||
"method_not_allowed": "不支持的请求方法",
|
||||
"conflict": "资源冲突",
|
||||
"too_many_requests": "请求过于频繁,请稍后再试",
|
||||
"validation_failed": "数据验证失败",
|
||||
"database_error": "数据库操作失败",
|
||||
"file_operation_error": "文件操作失败"
|
||||
},
|
||||
"auth": {
|
||||
"invalid_credentials": "用户名或密码错误",
|
||||
"token_expired": "登录已过期,请重新登录",
|
||||
"token_invalid": "无效的登录令牌",
|
||||
"token_missing": "缺少登录令牌",
|
||||
"unauthorized": "未授权访问",
|
||||
"forbidden": "没有权限执行此操作",
|
||||
"account_locked": "账户已被锁定",
|
||||
"account_disabled": "账户已被禁用",
|
||||
"account_not_verified": "账户未验证",
|
||||
"password_incorrect": "密码错误",
|
||||
"password_too_weak": "密码强度不够",
|
||||
"password_expired": "密码已过期,请修改密码",
|
||||
"email_not_verified": "邮箱未验证",
|
||||
"phone_not_verified": "手机号未验证",
|
||||
"verification_code_invalid": "验证码无效",
|
||||
"verification_code_expired": "验证码已过期",
|
||||
"login_failed": "登录失败",
|
||||
"logout_failed": "登出失败",
|
||||
"session_expired": "会话已过期",
|
||||
"already_logged_in": "已经登录",
|
||||
"not_logged_in": "未登录"
|
||||
},
|
||||
"user": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"email_already_exists": "邮箱已被使用",
|
||||
"phone_already_exists": "手机号已被使用",
|
||||
"username_already_exists": "用户名已被使用",
|
||||
"invalid_email": "邮箱格式不正确",
|
||||
"invalid_phone": "手机号格式不正确",
|
||||
"invalid_username": "用户名格式不正确",
|
||||
"create_failed": "创建用户失败",
|
||||
"update_failed": "更新用户失败",
|
||||
"delete_failed": "删除用户失败",
|
||||
"cannot_delete_self": "不能删除自己",
|
||||
"cannot_update_self_role": "不能修改自己的角色",
|
||||
"profile_update_failed": "更新个人资料失败",
|
||||
"avatar_upload_failed": "上传头像失败",
|
||||
"password_change_failed": "修改密码失败",
|
||||
"old_password_incorrect": "原密码错误"
|
||||
},
|
||||
"workspace": {
|
||||
"not_found": "工作空间不存在",
|
||||
"already_exists": "工作空间已存在",
|
||||
"name_required": "工作空间名称不能为空",
|
||||
"name_too_long": "工作空间名称过长",
|
||||
"create_failed": "创建工作空间失败",
|
||||
"update_failed": "更新工作空间失败",
|
||||
"delete_failed": "删除工作空间失败",
|
||||
"permission_denied": "没有权限访问此工作空间",
|
||||
"not_member": "不是工作空间成员",
|
||||
"already_member": "已经是工作空间成员",
|
||||
"member_limit_reached": "成员数量已达上限",
|
||||
"cannot_leave_last_manager": "不能离开,您是最后一个管理员",
|
||||
"cannot_remove_last_manager": "不能移除最后一个管理员",
|
||||
"cannot_remove_self": "不能移除自己",
|
||||
"invite_not_found": "邀请不存在",
|
||||
"invite_expired": "邀请已过期",
|
||||
"invite_already_accepted": "邀请已被接受",
|
||||
"invite_already_revoked": "邀请已被撤销",
|
||||
"invite_send_failed": "发送邀请失败",
|
||||
"archived": "工作空间已归档",
|
||||
"suspended": "工作空间已暂停"
|
||||
},
|
||||
"tenant": {
|
||||
"not_found": "租户不存在",
|
||||
"already_exists": "租户已存在",
|
||||
"create_failed": "创建租户失败",
|
||||
"update_failed": "更新租户失败",
|
||||
"delete_failed": "删除租户失败",
|
||||
"suspended": "租户已暂停",
|
||||
"expired": "租户已过期",
|
||||
"license_invalid": "许可证无效",
|
||||
"license_expired": "许可证已过期",
|
||||
"quota_exceeded": "配额已超限"
|
||||
},
|
||||
"file": {
|
||||
"not_found": "文件不存在",
|
||||
"upload_failed": "文件上传失败",
|
||||
"download_failed": "文件下载失败",
|
||||
"delete_failed": "文件删除失败",
|
||||
"too_large": "文件大小超过限制",
|
||||
"invalid_type": "不支持的文件类型",
|
||||
"invalid_format": "文件格式不正确",
|
||||
"corrupted": "文件已损坏",
|
||||
"storage_full": "存储空间已满",
|
||||
"access_denied": "没有权限访问此文件"
|
||||
},
|
||||
"api": {
|
||||
"rate_limit_exceeded": "API调用频率超限",
|
||||
"quota_exceeded": "API调用配额已用完",
|
||||
"invalid_api_key": "无效的API密钥",
|
||||
"api_key_expired": "API密钥已过期",
|
||||
"api_key_revoked": "API密钥已被撤销",
|
||||
"endpoint_not_found": "API端点不存在",
|
||||
"method_not_allowed": "不支持的请求方法",
|
||||
"invalid_request": "无效的请求",
|
||||
"missing_parameter": "缺少必需参数:{param}",
|
||||
"invalid_parameter": "参数无效:{param}"
|
||||
},
|
||||
"database": {
|
||||
"connection_failed": "数据库连接失败",
|
||||
"query_failed": "数据库查询失败",
|
||||
"transaction_failed": "数据库事务失败",
|
||||
"constraint_violation": "数据约束冲突",
|
||||
"duplicate_key": "数据重复",
|
||||
"foreign_key_violation": "外键约束冲突",
|
||||
"deadlock": "数据库死锁"
|
||||
},
|
||||
"validation": {
|
||||
"invalid_input": "输入数据无效",
|
||||
"missing_field": "缺少必需字段:{field}",
|
||||
"invalid_field": "字段无效:{field}",
|
||||
"field_too_long": "字段过长:{field}",
|
||||
"field_too_short": "字段过短:{field}",
|
||||
"invalid_format": "格式不正确:{field}",
|
||||
"invalid_value": "值无效:{field}",
|
||||
"out_of_range": "值超出范围:{field}"
|
||||
}
|
||||
}
|
||||
27
api/app/locales/zh/i18n.json
Normal file
27
api/app/locales/zh/i18n.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"language": {
|
||||
"not_found": "语言 {locale} 不存在",
|
||||
"already_exists": "语言 {locale} 已存在",
|
||||
"add_instructions": "语言 {locale} 验证成功。请在 {dir} 目录下创建翻译文件以完成添加。",
|
||||
"update_instructions": "语言 {locale} 更新验证成功。请更新环境变量 I18N_SUPPORTED_LANGUAGES 以应用配置更改。"
|
||||
},
|
||||
"namespace": {
|
||||
"not_found": "命名空间 {namespace} 在语言 {locale} 中不存在"
|
||||
},
|
||||
"translation": {
|
||||
"invalid_key_format": "翻译键格式无效: {key}。应使用格式: namespace.key.subkey",
|
||||
"update_instructions": "翻译 {locale}/{key} 更新验证成功。请修改对应的 JSON 翻译文件以应用更改。"
|
||||
},
|
||||
"reload": {
|
||||
"disabled": "翻译热重载功能已禁用。请在配置中启用 I18N_ENABLE_HOT_RELOAD。",
|
||||
"success": "翻译重载成功",
|
||||
"failed": "翻译重载失败: {error}"
|
||||
},
|
||||
"metrics": {
|
||||
"reset_success": "性能指标已重置"
|
||||
},
|
||||
"logs": {
|
||||
"export_success": "缺失翻译已导出到: {file}",
|
||||
"clear_success": "缺失翻译日志已清除"
|
||||
}
|
||||
}
|
||||
63
api/app/locales/zh/tenant.json
Normal file
63
api/app/locales/zh/tenant.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "租户信息获取成功",
|
||||
"get_failed": "租户信息获取失败",
|
||||
"update_success": "租户信息更新成功",
|
||||
"update_failed": "租户信息更新失败"
|
||||
},
|
||||
"create": {
|
||||
"success": "租户创建成功",
|
||||
"failed": "租户创建失败"
|
||||
},
|
||||
"delete": {
|
||||
"success": "租户删除成功",
|
||||
"failed": "租户删除失败"
|
||||
},
|
||||
"status": {
|
||||
"activate_success": "租户启用成功",
|
||||
"activate_failed": "租户启用失败",
|
||||
"deactivate_success": "租户禁用成功",
|
||||
"deactivate_failed": "租户禁用失败"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "租户语言配置获取成功",
|
||||
"get_failed": "租户语言配置获取失败",
|
||||
"update_success": "租户语言配置更新成功",
|
||||
"update_failed": "租户语言配置更新失败",
|
||||
"invalid_language": "不支持的语言代码",
|
||||
"default_not_in_supported": "默认语言必须在支持的语言列表中"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "租户列表获取成功",
|
||||
"get_failed": "租户列表获取失败"
|
||||
},
|
||||
"users": {
|
||||
"list_success": "租户用户列表获取成功",
|
||||
"list_failed": "租户用户列表获取失败",
|
||||
"assign_success": "用户分配到租户成功",
|
||||
"assign_failed": "用户分配到租户失败",
|
||||
"remove_success": "用户从租户移除成功",
|
||||
"remove_failed": "用户从租户移除失败"
|
||||
},
|
||||
"statistics": {
|
||||
"get_success": "租户统计信息获取成功",
|
||||
"get_failed": "租户统计信息获取失败"
|
||||
},
|
||||
"validation": {
|
||||
"name_required": "租户名称不能为空",
|
||||
"name_invalid": "租户名称格式不正确",
|
||||
"name_too_long": "租户名称长度不能超过{max}个字符",
|
||||
"description_too_long": "租户描述长度不能超过{max}个字符",
|
||||
"language_code_invalid": "语言代码格式不正确",
|
||||
"supported_languages_empty": "支持的语言列表不能为空"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "租户不存在",
|
||||
"already_exists": "租户名称已存在",
|
||||
"permission_denied": "没有权限访问此租户",
|
||||
"has_users": "无法删除租户,存在关联的用户",
|
||||
"has_workspaces": "无法删除租户,存在关联的工作空间",
|
||||
"already_active": "租户已处于启用状态",
|
||||
"already_inactive": "租户已处于禁用状态"
|
||||
}
|
||||
}
|
||||
72
api/app/locales/zh/users.json
Normal file
72
api/app/locales/zh/users.json
Normal file
@@ -0,0 +1,72 @@
|
||||
{
|
||||
"info": {
|
||||
"get_success": "用户信息获取成功",
|
||||
"get_failed": "用户信息获取失败",
|
||||
"update_success": "用户信息更新成功",
|
||||
"update_failed": "用户信息更新失败"
|
||||
},
|
||||
"create": {
|
||||
"success": "用户创建成功",
|
||||
"failed": "用户创建失败",
|
||||
"superuser_success": "超级管理员创建成功",
|
||||
"superuser_failed": "超级管理员创建失败"
|
||||
},
|
||||
"delete": {
|
||||
"success": "用户删除成功",
|
||||
"failed": "用户删除失败",
|
||||
"deactivate_success": "用户停用成功",
|
||||
"deactivate_failed": "用户停用失败"
|
||||
},
|
||||
"activate": {
|
||||
"success": "用户激活成功",
|
||||
"failed": "用户激活失败"
|
||||
},
|
||||
"language": {
|
||||
"get_success": "语言偏好获取成功",
|
||||
"get_failed": "语言偏好获取失败",
|
||||
"update_success": "语言偏好更新成功",
|
||||
"update_failed": "语言偏好更新失败",
|
||||
"invalid_language": "不支持的语言代码",
|
||||
"current": "当前语言偏好"
|
||||
},
|
||||
"email": {
|
||||
"change_success": "邮箱修改成功",
|
||||
"change_failed": "邮箱修改失败",
|
||||
"code_sent": "验证码已发送到您的邮箱,请查收",
|
||||
"code_send_failed": "验证码发送失败",
|
||||
"code_invalid": "验证码无效或已过期",
|
||||
"already_exists": "该邮箱已被使用"
|
||||
},
|
||||
"list": {
|
||||
"get_success": "用户列表获取成功",
|
||||
"get_failed": "用户列表获取失败",
|
||||
"superusers_success": "租户超管列表获取成功",
|
||||
"superusers_failed": "租户超管列表获取失败"
|
||||
},
|
||||
"validation": {
|
||||
"username_required": "用户名不能为空",
|
||||
"username_invalid": "用户名格式不正确",
|
||||
"username_too_long": "用户名长度不能超过{max}个字符",
|
||||
"email_required": "邮箱不能为空",
|
||||
"email_invalid": "邮箱格式不正确",
|
||||
"password_required": "密码不能为空",
|
||||
"password_too_short": "密码长度不能少于{min}个字符",
|
||||
"password_too_long": "密码长度不能超过{max}个字符",
|
||||
"old_password_required": "旧密码不能为空",
|
||||
"new_password_required": "新密码不能为空",
|
||||
"verification_code_required": "验证码不能为空",
|
||||
"verification_code_invalid": "验证码格式不正确"
|
||||
},
|
||||
"errors": {
|
||||
"not_found": "用户不存在",
|
||||
"already_exists": "用户已存在",
|
||||
"permission_denied": "没有权限访问此用户",
|
||||
"cannot_delete_self": "不能删除自己",
|
||||
"cannot_deactivate_self": "不能停用自己",
|
||||
"already_deactivated": "用户已被停用",
|
||||
"already_activated": "用户已处于激活状态",
|
||||
"password_verification_failed": "密码验证失败",
|
||||
"old_password_incorrect": "旧密码不正确",
|
||||
"same_as_old_password": "新密码不能与旧密码相同"
|
||||
}
|
||||
}
|
||||
44
api/app/locales/zh/workspace.json
Normal file
44
api/app/locales/zh/workspace.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"list_retrieved": "工作空间列表获取成功",
|
||||
"created": "工作空间创建成功",
|
||||
"updated": "工作空间更新成功",
|
||||
"deleted": "工作空间删除成功",
|
||||
"switched": "工作空间切换成功",
|
||||
"not_found": "工作空间不存在或无权访问",
|
||||
"already_exists": "工作空间已存在",
|
||||
"permission_denied": "没有权限访问此工作空间",
|
||||
"name_required": "工作空间名称不能为空",
|
||||
"invalid_name": "工作空间名称格式不正确",
|
||||
"members": {
|
||||
"list_retrieved": "工作空间成员列表获取成功",
|
||||
"role_updated": "成员角色更新成功",
|
||||
"deleted": "成员删除成功",
|
||||
"not_found": "成员不存在",
|
||||
"cannot_remove_self": "不能删除自己",
|
||||
"cannot_remove_last_manager": "不能删除最后一个管理员",
|
||||
"already_member": "用户已经是工作空间成员"
|
||||
},
|
||||
"invites": {
|
||||
"created": "邀请创建成功",
|
||||
"list_retrieved": "邀请列表获取成功",
|
||||
"validated": "邀请验证成功",
|
||||
"revoked": "邀请撤销成功",
|
||||
"accepted": "邀请已接受",
|
||||
"not_found": "邀请不存在",
|
||||
"expired": "邀请已过期",
|
||||
"already_used": "邀请已被使用",
|
||||
"invalid_token": "无效的邀请令牌",
|
||||
"email_required": "邮箱地址不能为空",
|
||||
"invalid_email": "邮箱地址格式不正确"
|
||||
},
|
||||
"storage": {
|
||||
"type_retrieved": "存储类型获取成功",
|
||||
"type_updated": "存储类型更新成功",
|
||||
"invalid_type": "无效的存储类型"
|
||||
},
|
||||
"models": {
|
||||
"config_retrieved": "模型配置获取成功",
|
||||
"config_updated": "模型配置更新成功",
|
||||
"invalid_config": "无效的模型配置"
|
||||
}
|
||||
}
|
||||
196
api/app/main.py
196
api/app/main.py
@@ -92,6 +92,10 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add i18n language detection middleware
|
||||
from app.i18n.middleware import LanguageMiddleware
|
||||
app.add_middleware(LanguageMiddleware)
|
||||
|
||||
logger.info("FastAPI应用程序启动")
|
||||
|
||||
|
||||
@@ -129,6 +133,11 @@ from app.core.exceptions import (
|
||||
from app.core.sensitive_filter import SensitiveDataFilter
|
||||
import traceback
|
||||
|
||||
# Import i18n exception support
|
||||
from app.i18n.exceptions import I18nException
|
||||
from app.i18n.service import get_translation_service
|
||||
from pydantic import ValidationError as PydanticValidationError
|
||||
|
||||
|
||||
# 处理验证异常
|
||||
@app.exception_handler(ValidationException)
|
||||
@@ -156,6 +165,131 @@ async def validation_exception_handler(request: Request, exc: ValidationExceptio
|
||||
)
|
||||
|
||||
|
||||
# 处理 i18n 异常(国际化异常)
|
||||
@app.exception_handler(I18nException)
async def i18n_exception_handler(request: Request, exc: I18nException):
    """Turn an ``I18nException`` into a JSON error response.

    Per the original author's note, ``I18nException`` already carries the
    translated error message, so no translation happens here: the detail is
    only passed through the sensitive-data filter and returned.

    Args:
        request: The incoming request. ``request.state.language`` (falling
            back to ``settings.I18N_DEFAULT_LANGUAGE``) is read for log
            context only.
        exc: The raised exception; its ``detail``, ``status_code``,
            ``headers``, ``error_code``, ``error_key`` and ``params``
            attributes are used.

    Returns:
        A ``JSONResponse`` with ``exc.status_code`` and ``exc.headers`` whose
        body is ``{"success": False, **detail}``.
    """
    # Language is only needed for the log record below.
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    detail = exc.detail

    # Scrub sensitive data from the user-visible message.
    if isinstance(detail, dict):
        filtered_detail = {
            **detail,
            "message": SensitiveDataFilter.filter_string(detail.get("message", "")),
        }
    else:
        # A non-dict detail (e.g. a plain string) cannot be unpacked with
        # ``**`` into the response body, so wrap it in the same shape the
        # dict branch produces instead of raising TypeError in the handler.
        filtered_detail = {
            "message": SensitiveDataFilter.filter_string(str(detail)),
        }

    logger.warning(
        f"I18n exception: {exc.error_key}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "error_code": exc.error_code,
            "error_key": exc.error_key,
            "language": language,
            "status_code": exc.status_code,
            "params": exc.params,
        },
    )

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            **filtered_detail,
        },
        headers=exc.headers,
    )
|
||||
|
||||
|
||||
# 处理 Pydantic 验证错误(国际化支持)
|
||||
@app.exception_handler(PydanticValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
    """Turn a Pydantic ``ValidationError`` into a localized 422 response.

    Each entry of ``exc.errors()`` is mapped to a translation key by its
    ``type`` and rendered in the request language; unknown types fall back to
    the generic "invalid field" message.

    Args:
        request: The incoming request. ``request.state.language`` (falling
            back to ``settings.I18N_DEFAULT_LANGUAGE``) selects the language.
        exc: The Pydantic validation error being handled.

    Returns:
        A ``JSONResponse`` with status 422 and a body containing
        ``success``, ``error_code`` (``"VALIDATION_FAILED"``), the translated
        ``message`` and the per-field ``errors`` list.
    """
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
    translation_service = get_translation_service()

    # Pydantic v1 and v2 report the same failures under different type names
    # (e.g. "value_error.missing" vs "missing"); accept both so the specific
    # messages are used regardless of the installed major version.
    translation_key_by_type = {
        "value_error.missing": "errors.validation.missing_field",
        "missing": "errors.validation.missing_field",
        "value_error.any_str.max_length": "errors.validation.field_too_long",
        "string_too_long": "errors.validation.field_too_long",
        "value_error.any_str.min_length": "errors.validation.field_too_short",
        "string_too_short": "errors.validation.field_too_short",
    }

    errors = []
    for error in exc.errors():
        # "loc" is a path of keys/indices; join it into a dotted field name.
        field = ".".join(str(loc) for loc in error["loc"])
        error_type = error["type"]

        message = translation_service.translate(
            # Anything not listed above gets the generic validation message.
            translation_key_by_type.get(error_type, "errors.validation.invalid_field"),
            language,
            field=field,
        )

        errors.append({
            "field": field,
            "message": message,
            "type": error_type,
        })

    # Top-level summary message.
    main_message = translation_service.translate(
        "errors.common.validation_failed",
        language,
    )

    logger.warning(
        f"Pydantic validation error: {len(errors)} errors",
        extra={
            "path": request.url.path,
            "method": request.method,
            "language": language,
            "errors": errors,
        },
    )

    return JSONResponse(
        status_code=422,
        content={
            "success": False,
            "error_code": "VALIDATION_FAILED",
            "message": main_message,
            "errors": errors,
        },
    )
|
||||
|
||||
|
||||
# 处理资源不存在异常
|
||||
@app.exception_handler(ResourceNotFoundException)
|
||||
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
|
||||
@@ -354,31 +488,66 @@ async def business_exception_handler(request: Request, exc: BusinessException):
|
||||
)
|
||||
|
||||
|
||||
# 统一异常处理:将HTTPException转换为统一响应结构
|
||||
# 统一异常处理:将HTTPException转换为统一响应结构(支持国际化)
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""处理HTTP异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))
|
||||
|
||||
"""处理HTTP异常,支持国际化"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取翻译服务
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# 尝试翻译标准HTTP错误
|
||||
error_key_map = {
|
||||
400: "errors.common.bad_request",
|
||||
401: "errors.common.unauthorized",
|
||||
403: "errors.common.forbidden",
|
||||
404: "errors.common.not_found",
|
||||
405: "errors.common.method_not_allowed",
|
||||
409: "errors.common.conflict",
|
||||
422: "errors.common.validation_failed",
|
||||
429: "errors.common.too_many_requests",
|
||||
500: "errors.common.internal_error",
|
||||
503: "errors.common.service_unavailable",
|
||||
}
|
||||
|
||||
# 如果有对应的翻译键,使用翻译
|
||||
if exc.status_code in error_key_map:
|
||||
translated_message = translation_service.translate(
|
||||
error_key_map[exc.status_code],
|
||||
language
|
||||
)
|
||||
else:
|
||||
# 否则过滤原始消息
|
||||
translated_message = SensitiveDataFilter.filter_string(str(exc.detail))
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {filtered_detail}",
|
||||
f"HTTP exception: {translated_message}",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"status_code": exc.status_code
|
||||
"status_code": exc.status_code,
|
||||
"language": language
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail)
|
||||
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
|
||||
)
|
||||
|
||||
|
||||
# 捕获未处理的异常,返回统一错误结构
|
||||
# 捕获未处理的异常,返回统一错误结构(支持国际化)
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
"""处理未捕获的异常"""
|
||||
"""处理未捕获的异常,支持国际化"""
|
||||
# 获取当前语言
|
||||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||||
|
||||
# 获取翻译服务
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# 记录完整的堆栈跟踪(日志过滤器会自动过滤敏感信息)
|
||||
logger.error(
|
||||
f"Unhandled exception: {exc}",
|
||||
@@ -386,6 +555,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"exception_type": type(exc).__name__,
|
||||
"language": language,
|
||||
"traceback": traceback.format_exc()
|
||||
},
|
||||
exc_info=True
|
||||
@@ -394,7 +564,11 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
# 生产环境隐藏详细错误信息
|
||||
environment = os.getenv("ENVIRONMENT", "development")
|
||||
if environment == "production":
|
||||
message = "服务器内部错误,请稍后重试"
|
||||
# 使用翻译的通用错误消息
|
||||
message = translation_service.translate(
|
||||
"errors.common.internal_error",
|
||||
language
|
||||
)
|
||||
else:
|
||||
# 开发环境也要过滤敏感信息
|
||||
message = SensitiveDataFilter.filter_string(str(exc))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, DateTime, ForeignKey
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
from sqlalchemy.orm import relationship
|
||||
@@ -18,6 +18,8 @@ class AppShare(Base):
|
||||
source_workspace_id = Column(UUID(as_uuid=True), ForeignKey('workspaces.id'), nullable=False, comment="源工作空间ID")
|
||||
target_workspace_id = Column(UUID(as_uuid=True), ForeignKey('workspaces.id'), nullable=False, comment="目标工作空间ID")
|
||||
shared_by = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False, comment="分享者用户ID")
|
||||
permission = Column(String, default="readonly", nullable=False, comment="权限模式: readonly | editable")
|
||||
is_active = Column(Boolean, default=True, server_default='true', nullable=False, comment="是否有效,False 表示逻辑删除")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.db import Base
|
||||
from app.schemas import FileType
|
||||
|
||||
|
||||
class PerceptualType(IntEnum):
|
||||
@@ -15,6 +16,16 @@ class PerceptualType(IntEnum):
|
||||
TEXT = 3
|
||||
CONVERSATION = 4
|
||||
|
||||
@staticmethod
def trans_from_file_type(file_type: FileType | str):
    """Return the perceptual type that corresponds to a file type.

    Images and videos map to ``VISION``, audio maps to ``AUDIO`` and
    documents map to ``TEXT``. Any value that is not one of those keys
    yields ``TEXT``.
    """
    perceptual_by_file_type = dict((
        (FileType.IMAGE, PerceptualType.VISION),
        (FileType.AUDIO, PerceptualType.AUDIO),
        (FileType.VIDEO, PerceptualType.VISION),
        (FileType.DOCUMENT, PerceptualType.TEXT),
    ))
    try:
        return perceptual_by_file_type[file_type]
    except KeyError:
        # Unknown file types are treated as text.
        return PerceptualType.TEXT
|
||||
|
||||
|
||||
class FileStorageService(IntEnum):
|
||||
LOCAL = 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import Column, String, DateTime, Boolean, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, ARRAY
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
@@ -20,6 +20,10 @@ class Tenants(Base):
|
||||
external_id = Column(String(100), nullable=True, index=True) # 外部企业ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 国际化语言配置字段
|
||||
default_language = Column(String(10), nullable=False, default='zh', server_default='zh', index=True) # 租户默认语言
|
||||
supported_languages = Column(ARRAY(String(10)), nullable=False, default=lambda: ['zh', 'en'], server_default=text("'{zh,en}'")) # 租户支持的语言列表
|
||||
|
||||
# Relationship to users - one tenant has many users
|
||||
users = relationship("User", back_populates="tenant")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
@@ -22,6 +22,9 @@ class User(Base):
|
||||
external_id = Column(String(100), nullable=True) # 外部用户ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 用户语言偏好
|
||||
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
||||
|
||||
current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True) # 当前工作空间ID,可为空
|
||||
|
||||
# Foreign key to tenant - each user belongs to exactly one tenant
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy import and_, desc, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -127,6 +127,17 @@ class MemoryPerceptualRepository:
|
||||
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_url(
        self,
        file_url: str
) -> list[MemoryPerceptualModel]:
    """Return every perceptual memory whose ``file_path`` equals ``file_url``.

    Args:
        file_url: Value compared for equality against ``MemoryPerceptualModel.file_path``.

    Returns:
        A list of matching rows (empty when nothing matches).

    Raises:
        Exception: Any error raised while building or executing the query is
            logged and re-raised unchanged.
    """
    try:
        stmt = select(MemoryPerceptualModel).where(MemoryPerceptualModel.file_path == file_url)
        return list(self.db.execute(stmt).scalars())
    except Exception as e:
        # Include the exception text, matching the other query methods of this
        # repository; without it the log line gives no hint about the cause.
        db_logger.error(f"Failed to query perceptual memories by file_url: file_url={file_url} - {str(e)}")
        raise
|
||||
|
||||
def get_by_type(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
|
||||
194
api/app/repositories/neo4j/community_repository.py
Normal file
194
api/app/repositories/neo4j/community_repository.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Community 节点仓库
|
||||
|
||||
管理 Neo4j 中 Community 节点及 BELONGS_TO_COMMUNITY 边的 CRUD 操作。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
COMMUNITY_NODE_UPSERT,
|
||||
ENTITY_JOIN_COMMUNITY,
|
||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||
GET_ENTITY_NEIGHBORS,
|
||||
GET_ALL_ENTITIES_FOR_USER,
|
||||
GET_COMMUNITY_MEMBERS,
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommunityRepository:
    """Data-access layer for Community nodes and BELONGS_TO_COMMUNITY edges in Neo4j.

    Every method is best-effort: when a query raises, the error is logged
    together with its traceback and a neutral fallback (``None``, ``False``,
    ``0``, ``[]`` or ``{}``) is returned instead of propagating the exception.
    """

    def __init__(self, connector: Neo4jConnector):
        # Connector used to execute every Cypher query issued by this repository.
        self.connector = connector

    async def upsert_community(
        self, community_id: str, end_user_id: str, member_count: int = 0
    ) -> Optional[str]:
        """Create or update a Community node.

        Returns:
            The ``community_id`` of the first returned row, or ``None`` when the
            query returns no rows or fails.
        """
        try:
            result = await self.connector.execute_query(
                COMMUNITY_NODE_UPSERT,
                community_id=community_id,
                end_user_id=end_user_id,
                member_count=member_count,
            )
            return result[0]["community_id"] if result else None
        except Exception as e:
            logger.error(f"upsert_community failed: {e}", exc_info=True)
            return None

    async def assign_entity_to_community(
        self, entity_id: str, community_id: str, end_user_id: str
    ) -> bool:
        """Attach an entity to a community, detaching it from any previous one first.

        Returns:
            ``True`` when the join query returned at least one row, ``False``
            when it returned nothing or either query failed.
        """
        # NOTE(review): leave and join are two separate execute_query calls. If
        # the join fails after the leave succeeded, the entity ends up without
        # a community - confirm whether that is acceptable to callers.
        try:
            await self.connector.execute_query(
                ENTITY_LEAVE_ALL_COMMUNITIES,
                entity_id=entity_id,
                end_user_id=end_user_id,
            )
            result = await self.connector.execute_query(
                ENTITY_JOIN_COMMUNITY,
                entity_id=entity_id,
                community_id=community_id,
                end_user_id=end_user_id,
            )
            return bool(result)
        except Exception as e:
            logger.error(f"assign_entity_to_community failed: {e}", exc_info=True)
            return False

    async def get_entity_neighbors(
        self, entity_id: str, end_user_id: str
    ) -> List[Dict]:
        """Return the direct neighbors of an entity with their community membership.

        Returns:
            The rows produced by ``GET_ENTITY_NEIGHBORS``, or ``[]`` on failure.
        """
        try:
            return await self.connector.execute_query(
                GET_ENTITY_NEIGHBORS,
                entity_id=entity_id,
                end_user_id=end_user_id,
            )
        except Exception as e:
            logger.error(f"get_entity_neighbors failed: {e}", exc_info=True)
            return []

    async def get_all_entity_neighbors_batch(
        self, end_user_id: str
    ) -> Dict[str, List[Dict]]:
        """Fetch the neighbors of every entity of a user with a single query.

        Intended for preloading before a full clustering run so that neighbors
        do not have to be queried entity by entity.

        Returns:
            ``{entity_id: [neighbor, ...]}`` where each neighbor is the result
            row without its ``entity_id`` key; ``{}`` on failure.
        """
        try:
            rows = await self.connector.execute_query(
                GET_ALL_ENTITY_NEIGHBORS_BATCH,
                end_user_id=end_user_id,
            )
            result: Dict[str, List[Dict]] = {}
            for row in rows:
                eid = row["entity_id"]
                # The grouping key is redundant inside the grouped value.
                neighbor = {k: v for k, v in row.items() if k != "entity_id"}
                result.setdefault(eid, []).append(neighbor)
            return result
        except Exception as e:
            logger.error(f"get_all_entity_neighbors_batch failed: {e}", exc_info=True)
            return {}

    async def get_all_entities(self, end_user_id: str) -> List[Dict]:
        """Return every entity of a user together with its current community.

        Returns:
            The rows produced by ``GET_ALL_ENTITIES_FOR_USER``, or ``[]`` on failure.
        """
        try:
            return await self.connector.execute_query(
                GET_ALL_ENTITIES_FOR_USER,
                end_user_id=end_user_id,
            )
        except Exception as e:
            logger.error(f"get_all_entities failed: {e}", exc_info=True)
            return []

    async def get_community_members(
        self, community_id: str, end_user_id: str
    ) -> List[Dict]:
        """Return the members of one community.

        Returns:
            The rows produced by ``GET_COMMUNITY_MEMBERS``, or ``[]`` on failure.
        """
        try:
            return await self.connector.execute_query(
                GET_COMMUNITY_MEMBERS,
                community_id=community_id,
                end_user_id=end_user_id,
            )
        except Exception as e:
            logger.error(f"get_community_members failed: {e}", exc_info=True)
            return []

    async def get_all_community_members_batch(
        self, community_ids: List[str], end_user_id: str
    ) -> Dict[str, List[Dict]]:
        """Fetch the members of several communities with a single query.

        Returns:
            ``{community_id: [member_row, ...]}``; ``{}`` when ``community_ids``
            is empty or the query fails. Member rows keep their
            ``community_id`` key.
        """
        # Nothing can match an empty id list, so skip the database round trip.
        if not community_ids:
            return {}
        try:
            rows = await self.connector.execute_query(
                GET_ALL_COMMUNITY_MEMBERS_BATCH,
                community_ids=community_ids,
                end_user_id=end_user_id,
            )
            result: Dict[str, List[Dict]] = {}
            for row in rows:
                cid = row["community_id"]
                result.setdefault(cid, []).append(row)
            return result
        except Exception as e:
            logger.error(f"get_all_community_members_batch failed: {e}", exc_info=True)
            return {}

    async def has_communities(self, end_user_id: str) -> bool:
        """Tell whether the user already owns at least one Community node.

        Per the original documentation this is used to choose between a full
        and an incremental clustering run.

        Returns:
            ``True`` when the reported ``community_count`` is positive;
            ``False`` when there are no rows or the query fails.
        """
        try:
            result = await self.connector.execute_query(
                CHECK_USER_HAS_COMMUNITIES,
                end_user_id=end_user_id,
            )
            return result[0]["community_count"] > 0 if result else False
        except Exception as e:
            logger.error(f"has_communities failed: {e}", exc_info=True)
            return False

    async def refresh_member_count(
        self, community_id: str, end_user_id: str
    ) -> int:
        """Recount the members of a community and store the new count on the node.

        Returns:
            The ``member_count`` of the first returned row, or ``0`` when the
            query returns no rows or fails.
        """
        try:
            result = await self.connector.execute_query(
                UPDATE_COMMUNITY_MEMBER_COUNT,
                community_id=community_id,
                end_user_id=end_user_id,
            )
            return result[0]["member_count"] if result else 0
        except Exception as e:
            logger.error(f"refresh_member_count failed: {e}", exc_info=True)
            return 0

    async def update_community_metadata(
        self,
        community_id: str,
        end_user_id: str,
        name: str,
        summary: str,
        core_entities: List[str],
    ) -> bool:
        """Update the name, summary and core-entity list of a community.

        Returns:
            ``True`` when the query returned at least one row, ``False`` when it
            returned nothing or failed.
        """
        try:
            result = await self.connector.execute_query(
                UPDATE_COMMUNITY_METADATA,
                community_id=community_id,
                end_user_id=end_user_id,
                name=name,
                summary=summary,
                core_entities=core_entities,
            )
            return bool(result)
        except Exception as e:
            logger.error(f"update_community_metadata failed: {e}", exc_info=True)
            return False
|
||||
@@ -1058,4 +1058,147 @@ Graph_Node_query = """
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================
# Community nodes & BELONGS_TO_COMMUNITY edges
# ============================================================

# ─── Cypher templates for community clustering ───────────────────────────────

# Create the Community node if missing (matched on community_id only), then
# overwrite its owner, member count and update timestamp.
COMMUNITY_NODE_UPSERT = """
MERGE (c:Community {community_id: $community_id})
SET c.end_user_id = $end_user_id,
c.member_count = $member_count,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

# Link an entity to a community of the same end user (idempotent via MERGE).
# Returns no row when either the entity or the community is not found.
ENTITY_JOIN_COMMUNITY = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c)
SET c.updated_at = datetime()
RETURN e.id AS entity_id, c.community_id AS community_id
"""

# Delete every BELONGS_TO_COMMUNITY edge leaving the given entity. Returns nothing.
ENTITY_LEAVE_ALL_COMMUNITIES = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
MATCH (e)-[r:BELONGS_TO_COMMUNITY]->(:Community)
DELETE r
"""

# Neighbors of one entity, gathered from two sources: entities connected by an
# EXTRACTED_RELATIONSHIP edge (either direction) and entities referenced by the
# same Statement. Each neighbor row carries its community_id (null if none).
GET_ENTITY_NEIGHBORS = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})

// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边)
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})

// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id

WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# Every entity of an end user with its community_id (null when unassigned).
GET_ALL_ENTITIES_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN e.id AS id,
e.name AS name,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# Members of one community, ordered by activation_value descending
# (a missing activation_value sorts as 0).
GET_COMMUNITY_MEMBERS = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding
ORDER BY coalesce(e.activation_value, 0) DESC
"""

# Members of all communities whose id is in $community_ids, one row per
# (community, entity) pair.
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
WHERE c.community_id IN $community_ids
RETURN c.community_id AS community_id,
e.id AS id,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value
"""

# Number of Community nodes owned by an end user.
CHECK_USER_HAS_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
RETURN count(c) AS community_count
"""
|
||||
|
||||
# Recount the entities of the end user attached to a community and store the
# result on the node. The community is matched first and its members are an
# OPTIONAL MATCH, so a community that has lost all of its members still yields
# one row and has member_count reset to 0 (count() ignores the null entity).
# With a single mandatory MATCH the query returned no row in that case and the
# stale count stayed on the node.
UPDATE_COMMUNITY_MEMBER_COUNT = """
MATCH (c:Community {community_id: $community_id})
OPTIONAL MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
WITH c, count(e) AS cnt
SET c.member_count = cnt
RETURN c.community_id AS community_id, cnt AS member_count
"""
|
||||
|
||||
# Overwrite name, summary and core_entities of a community owned by the end
# user and refresh its update timestamp. Returns no row when nothing matches.
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

# Batch variant of GET_ENTITY_NEIGHBORS: the same two neighbor sources, computed
# for every entity of the end user. Each row is tagged with the owning entity's
# id as entity_id.
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})

// 来源一:直接关系邻居
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})

// 来源二:同 Statement 共现邻居
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id

WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# Graph payload for an end user's communities: each community, each member
# entity with the membership edge, and (optionally) every EXTRACTED_RELATIONSHIP
# edge of that member to another entity of the same user. r_from_e is true when
# the relationship starts at the member entity; the e2_* / r_* columns are null
# when the member has no such relationship.
GET_COMMUNITY_GRAPH_DATA = """
MATCH (c:Community {end_user_id: $end_user_id})
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c)
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id})
RETURN
elementId(c) AS c_id,
properties(c) AS c_props,
elementId(e) AS e_id,
properties(e) AS e_props,
elementId(b) AS b_id,
elementId(e2) AS e2_id,
properties(e2) AS e2_props,
elementId(r) AS r_id,
type(r) AS r_type,
properties(r) AS r_props,
startNode(r) = e AS r_from_e
"""
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import List
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -155,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
connector: Neo4jConnector,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
@@ -288,6 +292,10 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
}
|
||||
logger.info("Transaction completed. Summary: %s", summary)
|
||||
logger.debug("Full transaction results: %r", results)
|
||||
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -295,3 +303,55 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
print(f"Neo4j integration error: {e}")
|
||||
print("Continuing without database storage...")
|
||||
return False
|
||||
|
||||
|
||||
# Strong references to in-flight clustering tasks. The event loop keeps only
# weak references to tasks, so a fire-and-forget task that nobody holds may be
# garbage-collected before it finishes.
_clustering_background_tasks: set = set()


def schedule_clustering_after_write(
    entity_nodes: List,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
) -> None:
    """Schedule the background clustering job after a successful Neo4j write.

    Clustering can be switched off with the environment variable
    ``CLUSTERING_ENABLED=false`` (e.g. for benchmark comparisons). The job is
    started as an asyncio task on the running loop, so the caller is not
    blocked while clustering runs.

    Args:
        entity_nodes: Entities that were just written; nothing is scheduled
            when this is empty. The ``end_user_id`` of the first entity is
            used for the whole batch.
        config_id: Forwarded unchanged to the clustering trigger.
        llm_model_id: Forwarded unchanged to the clustering trigger.

    Raises:
        RuntimeError: If called while no event loop is running.
    """
    if not entity_nodes:
        return

    clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
    if not clustering_enabled:
        logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
        return

    # Resolve the loop before creating the coroutine so that a missing loop
    # does not leave a never-awaited coroutine object behind.
    loop = asyncio.get_running_loop()

    end_user_id = entity_nodes[0].end_user_id
    new_entity_ids = [e.id for e in entity_nodes]
    logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
    task = loop.create_task(
        _trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)
    )
    # Hold the task until it completes, then drop the reference.
    _clustering_background_tasks.add(task)
    task.add_done_callback(_clustering_background_tasks.discard)
|
||||
|
||||
|
||||
async def _trigger_clustering(
    new_entity_ids: List[str],
    end_user_id: str,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
) -> None:
    """Run one clustering pass for an end user on a dedicated Neo4j connector.

    Per the original documentation the engine decides on its own between a
    full initialisation and an incremental update. Failures are logged with a
    traceback and never propagated, because this coroutine runs as a
    background task.

    Args:
        new_entity_ids: Ids of the entities that were just written.
        end_user_id: Owner of those entities.
        config_id: Forwarded to ``LabelPropagationEngine``.
        llm_model_id: Forwarded to ``LabelPropagationEngine``.
    """
    connector = None
    try:
        # Imported lazily, as in the original (presumably to avoid an import cycle - confirm).
        from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
        logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
        connector = Neo4jConnector()
        engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
        await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
        logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
    except Exception as e:
        logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True)
    finally:
        if connector is not None:
            try:
                await connector.close()
            except Exception as close_err:
                # Closing stays best-effort, but a failure is recorded rather
                # than silently dropped so leaked connections can be traced.
                logger.warning(f"[Clustering] Failed to close Neo4j connector: {close_err}")
|
||||
|
||||
@@ -277,7 +277,12 @@ class App(BaseModel):
|
||||
tags: List[str] = []
|
||||
current_release_id: Optional[uuid.UUID] = None
|
||||
is_active: bool
|
||||
is_shared: bool = False # 是否是共享应用(从其他工作空间共享来的)
|
||||
is_shared: bool = False
|
||||
share_permission: Optional[str] = None
|
||||
source_workspace_name: Optional[str] = None # 共享来源工作空间名称(仅共享应用有值)
|
||||
source_workspace_icon: Optional[str] = None # 共享来源工作空间图标
|
||||
source_app_version: Optional[str] = None # 应用版本号
|
||||
source_app_is_active: Optional[bool] = None # 应用是否生效
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@@ -422,6 +427,12 @@ class AppRelease(BaseModel):
|
||||
class AppShareCreate(BaseModel):
    """Request body for sharing an application with other workspaces."""

    # Workspaces that should receive the share (required).
    target_workspace_ids: List[uuid.UUID] = Field(
        ...,
        description="目标工作空间ID列表",
    )
    # Permission mode granted to the targets; defaults to read-only.
    permission: str = Field(
        description="权限模式: readonly | editable",
        default="readonly",
    )
|
||||
|
||||
|
||||
class UpdateSharePermissionRequest(BaseModel):
    """Request body for changing the permission of an existing share."""

    # New permission value (required).
    permission: str = Field(
        ...,
        description="新权限值: readonly | editable",
    )
|
||||
|
||||
|
||||
class AppShare(BaseModel):
|
||||
@@ -433,9 +444,32 @@ class AppShare(BaseModel):
|
||||
source_workspace_id: uuid.UUID
|
||||
target_workspace_id: uuid.UUID
|
||||
shared_by: uuid.UUID
|
||||
permission: str = "readonly"
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
# 关联名称(从 relationship 读取)
|
||||
source_app_name: Optional[str] = None
|
||||
source_app_type: Optional[str] = None
|
||||
source_app_version: Optional[str] = None
|
||||
source_app_is_active: Optional[bool] = None
|
||||
target_workspace_name: Optional[str] = None
|
||||
target_workspace_icon: Optional[str] = None
|
||||
|
||||
@classmethod
def model_validate(cls, obj, **kwargs):
    """Validate ``obj`` and copy display fields from its related objects.

    After the regular validation, the name, type, active flag and current
    release version are copied from ``obj.source_app``, and the name and icon
    from ``obj.target_workspace``, when those attributes exist and are truthy.
    """
    share = super().model_validate(obj, **kwargs)

    app = getattr(obj, 'source_app', None)
    if app:
        share.source_app_name = app.name
        share.source_app_type = app.type
        share.source_app_is_active = app.is_active
        current = app.current_release
        # An app without a current release has no version to report.
        share.source_app_version = current.version_name if current else None

    workspace = getattr(obj, 'target_workspace', None)
    if workspace:
        share.target_workspace_name = workspace.name
        share.target_workspace_icon = workspace.icon

    return share
|
||||
|
||||
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
    """Emit ``created_at`` as integer epoch milliseconds in JSON output (``None`` when unset)."""
    if not dt:
        return None
    millis = dt.timestamp() * 1000
    return int(millis)
|
||||
|
||||
73
api/app/schemas/i18n_schema.py
Normal file
73
api/app/schemas/i18n_schema.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
I18n Management API Schemas
|
||||
|
||||
This module defines Pydantic schemas for i18n management APIs.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management Schemas
|
||||
# ============================================================================
|
||||
|
||||
class LanguageInfo(BaseModel):
    """Description of one language known to the i18n subsystem (all fields required)."""

    code: str = Field(
        ...,
        description="Language code (e.g., 'zh', 'en')",
    )
    name: str = Field(
        ...,
        description="Language name (e.g., 'Chinese', 'English')",
    )
    native_name: str = Field(
        ...,
        description="Native language name (e.g., '中文', 'English')",
    )
    is_enabled: bool = Field(
        ...,
        description="Whether the language is enabled",
    )
    is_default: bool = Field(
        ...,
        description="Whether this is the default language",
    )
|
||||
|
||||
|
||||
class LanguageListResponse(BaseModel):
    """Response payload carrying the list of languages."""

    languages: List[LanguageInfo] = Field(
        ...,
        description="List of available languages",
    )
|
||||
|
||||
|
||||
class LanguageCreateRequest(BaseModel):
    """Request payload for registering a new language."""

    # 2-10 characters; no format check beyond the length is applied here.
    code: str = Field(
        ...,
        min_length=2,
        max_length=10,
        description="Language code (e.g., 'ja', 'ko')",
    )
    name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="Language name",
    )
    native_name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        description="Native language name",
    )
    # New languages are enabled unless the caller says otherwise.
    is_enabled: bool = Field(
        default=True,
        description="Whether to enable the language",
    )
|
||||
|
||||
|
||||
class LanguageUpdateRequest(BaseModel):
    """Request payload for changing a language's configuration; both fields are optional."""

    is_enabled: Optional[bool] = Field(
        default=None,
        description="Whether the language is enabled",
    )
    is_default: Optional[bool] = Field(
        default=None,
        description="Whether this is the default language",
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management Schemas
|
||||
# ============================================================================
|
||||
|
||||
class TranslationResponse(BaseModel):
    """Response payload carrying translation data."""

    # Nested mapping; per the field description: locale -> namespace -> content.
    translations: Dict[str, Dict[str, Any]] = Field(
        ..., description="Translations organized by locale and namespace"
    )
|
||||
|
||||
|
||||
class TranslationUpdateRequest(BaseModel):
    """Request payload for updating a single translation."""

    # The new text; an empty string is rejected by min_length.
    value: str = Field(
        ...,
        min_length=1,
        description="New translation value",
    )
    description: Optional[str] = Field(
        default=None,
        description="Optional description of the translation",
    )
|
||||
|
||||
|
||||
class MissingTranslationsResponse(BaseModel):
    """Response payload listing translation keys that are missing."""

    # Per the field description: locale -> list of missing keys.
    missing_translations: Dict[str, List[str]] = Field(
        ..., description="Missing translation keys organized by locale"
    )
|
||||
|
||||
|
||||
class ReloadResponse(BaseModel):
    """Response payload describing the outcome of a translation reload (all fields required)."""

    success: bool = Field(
        ...,
        description="Whether the reload was successful",
    )
    reloaded_locales: List[str] = Field(
        ...,
        description="List of reloaded locales",
    )
    total_locales: int = Field(
        ...,
        description="Total number of available locales",
    )
|
||||
@@ -25,5 +25,6 @@ class AgentMemory_Long_Term(ABC):
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
TIME_SCOPE=5
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -85,7 +84,6 @@ class Semantic(BaseModel):
|
||||
|
||||
|
||||
class Content(BaseModel):
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
|
||||
@@ -326,3 +326,14 @@ class ModelBaseQuery(BaseModel):
|
||||
is_official: Optional[bool] = Field(None, description="是否官方模型")
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
class ModelInfo(BaseModel):
    """Connection and capability details of one model."""

    model_name: str = Field(
        ...,
        description="模型名称",
    )
    provider: str = Field(
        ...,
        description="模型提供商",
    )
    # Credential and endpoint used to reach the provider.
    api_key: str = Field(
        ...,
        description="API密钥",
    )
    api_base: str = Field(
        ...,
        description="API基础URL",
    )
    is_omni: bool = Field(
        description="是否为omni模型",
        default=False,
    )
    model_type: ModelType = Field(
        ...,
        description="模型类型",
    )
    # A fresh empty list per instance when the caller gives none.
    capability: List[str] = Field(
        description="模型能力列表",
        default_factory=list,
    )
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ class TenantBase(BaseModel):
|
||||
name: str = Field(..., description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
default_language: Optional[str] = Field('zh', description="租户默认语言", max_length=10)
|
||||
supported_languages: Optional[List[str]] = Field(['zh', 'en'], description="租户支持的语言列表")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
@@ -18,6 +20,26 @@ class TenantBase(BaseModel):
|
||||
if not v or not v.strip():
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip()
|
||||
|
||||
@field_validator('default_language')
@classmethod
def validate_default_language(cls, v):
    """Check that the default language is a well-formed language code.

    Accepted forms are a two-letter lowercase code, optionally followed by a
    hyphen and a two-letter uppercase region (``zh``, ``en-US``). Falsy values
    (``None`` or an empty string) are returned unchecked.

    Raises:
        ValidationException: If a non-empty value does not match the format.
    """
    if v:
        import re
        # fullmatch instead of match with ^...$: `$` also matches just before
        # a trailing newline, which let values such as "zh\n" through.
        if not re.fullmatch(r'[a-z]{2}(-[A-Z]{2})?', v):
            raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
    return v
|
||||
|
||||
@field_validator('supported_languages')
@classmethod
def validate_supported_languages(cls, v):
    """Check that every supported language is a well-formed language code.

    Each entry must be a two-letter lowercase code, optionally followed by a
    hyphen and a two-letter uppercase region. A falsy value (``None`` or an
    empty list) is returned unchecked.

    Raises:
        ValidationException: For the first entry that does not match the format.
    """
    if v:
        import re
        # fullmatch instead of match with ^...$: `$` also matches just before
        # a trailing newline, which let entries such as "zh\n" through.
        pattern = re.compile(r'[a-z]{2}(-[A-Z]{2})?')
        for lang in v:
            if not pattern.fullmatch(lang):
                raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
    return v
|
||||
|
||||
|
||||
class TenantCreate(TenantBase):
|
||||
@@ -30,6 +52,8 @@ class TenantUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
default_language: Optional[str] = Field(None, description="租户默认语言", max_length=10)
|
||||
supported_languages: Optional[List[str]] = Field(None, description="租户支持的语言列表")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
@@ -37,6 +61,25 @@ class TenantUpdate(BaseModel):
|
||||
if v is not None and (not v or not v.strip()):
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator('default_language')
@classmethod
def validate_default_language(cls, v):
    """Check the default language of an update request, when one is supplied.

    Accepted forms are a two-letter lowercase code, optionally followed by a
    hyphen and a two-letter uppercase region. Falsy values (``None`` or an
    empty string) are returned unchecked.

    Raises:
        ValidationException: If a non-empty value does not match the format.
    """
    if v:
        import re
        # fullmatch instead of match with ^...$: `$` also matches just before
        # a trailing newline, which let values such as "zh\n" through.
        if not re.fullmatch(r'[a-z]{2}(-[A-Z]{2})?', v):
            raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
    return v
|
||||
|
||||
@field_validator('supported_languages')
@classmethod
def validate_supported_languages(cls, v):
    """Check the supported languages of an update request, when supplied.

    Each entry must be a two-letter lowercase code, optionally followed by a
    hyphen and a two-letter uppercase region. A falsy value (``None`` or an
    empty list) is returned unchecked.

    Raises:
        ValidationException: For the first entry that does not match the format.
    """
    if v:
        import re
        # fullmatch instead of match with ^...$: `$` also matches just before
        # a trailing newline, which let entries such as "zh\n" through.
        pattern = re.compile(r'[a-z]{2}(-[A-Z]{2})?')
        for lang in v:
            if not pattern.fullmatch(lang):
                raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
    return v
|
||||
|
||||
|
||||
class Tenant(TenantBase):
|
||||
@@ -62,4 +105,29 @@ class TenantList(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
size: int
|
||||
pages: int
|
||||
pages: int
|
||||
|
||||
|
||||
class TenantLanguageConfig(BaseModel):
    """Language configuration of a tenant; both fields are required."""

    default_language: str = Field(..., description="租户默认语言", max_length=10)
    supported_languages: List[str] = Field(..., description="租户支持的语言列表")

    @field_validator('default_language')
    @classmethod
    def validate_default_language(cls, v):
        """Require a two-letter lowercase code with an optional ``-XX`` region.

        Raises:
            ValidationException: If the value does not match the format.
        """
        import re
        # fullmatch instead of match with ^...$: `$` also matches just before
        # a trailing newline, which let values such as "zh\n" through.
        if not re.fullmatch(r'[a-z]{2}(-[A-Z]{2})?', v):
            raise ValidationException('语言代码格式不正确', code=BizCode.VALIDATION_FAILED)
        return v

    @field_validator('supported_languages')
    @classmethod
    def validate_supported_languages(cls, v):
        """Require a non-empty list whose entries are all well-formed language codes.

        Raises:
            ValidationException: If the list is empty, or for the first entry
                that does not match the format.
        """
        if not v:
            raise ValidationException('支持的语言列表不能为空', code=BizCode.VALIDATION_FAILED)
        import re
        # Same trailing-newline concern as for the default language.
        pattern = re.compile(r'[a-z]{2}(-[A-Z]{2})?')
        for lang in v:
            if not pattern.fullmatch(lang):
                raise ValidationException(f'语言代码格式不正确: {lang}', code=BizCode.VALIDATION_FAILED)
        return v
|
||||
|
||||
@@ -58,6 +58,16 @@ class VerifyPasswordRequest(BaseModel):
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class LanguagePreferenceRequest(BaseModel):
    """Request body for setting the user's language preference."""

    # Required; length-checked only (2-10 characters).
    language: str = Field(
        ...,
        description="语言代码,如 'zh', 'en'",
        max_length=10,
        min_length=2,
    )
|
||||
|
||||
|
||||
class LanguagePreferenceResponse(BaseModel):
    """Response body reporting the user's current language preference."""

    language: str = Field(
        ...,
        description="当前语言偏好",
    )
|
||||
|
||||
|
||||
class ChangePasswordResponse(BaseModel):
|
||||
"""修改密码响应"""
|
||||
message: str
|
||||
@@ -74,6 +84,7 @@ class User(UserBase):
|
||||
current_workspace_id: Optional[uuid.UUID] = None
|
||||
current_workspace_name: Optional[str] = None
|
||||
role: Optional[WorkspaceRole] = None
|
||||
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
||||
|
||||
# 将 datetime 转换为毫秒时间戳
|
||||
@validator("created_at", pre=True)
|
||||
|
||||
@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||
AgentRunService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -126,8 +122,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -266,8 +271,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
|
||||
@@ -12,7 +12,7 @@ import uuid
|
||||
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy import and_, delete, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -102,7 +102,8 @@ class AppService:
|
||||
# 2. 检查是否是共享给本工作空间的应用
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app.id,
|
||||
AppShare.target_workspace_id == workspace_id
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -125,6 +126,50 @@ class AppService:
|
||||
)
|
||||
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
|
||||
|
||||
def _get_share_permission(self, app: App, workspace_id: Optional[uuid.UUID]) -> Optional[str]:
    """Return the permission under which ``app`` is shared to ``workspace_id``.

    Args:
        app: The application being inspected.
        workspace_id: The workspace asking about the app; may be None.

    Returns:
        None when no workspace is given, when the app belongs to
        ``workspace_id`` itself (it is not a shared app), or when no active
        share record exists. Otherwise the ``permission`` value of the
        active share record (e.g. 'readonly' or 'editable').
    """
    from app.models import AppShare

    # No workspace context: nothing to look up.
    if workspace_id is None:
        return None
    # The app lives in the caller's own workspace, so it is not shared.
    if app.workspace_id == workspace_id:
        return None

    # Only active share records count; deactivated ones are ignored.
    active_share_query = select(AppShare).where(
        AppShare.source_app_id == app.id,
        AppShare.target_workspace_id == workspace_id,
        AppShare.is_active.is_(True)
    )
    active_share = self.db.scalars(active_share_query).first()
    if not active_share:
        return None
    return active_share.permission
|
||||
|
||||
def _validate_app_writable(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
    """Ensure the caller's workspace may modify the app's configuration.

    Writing is allowed when no workspace is supplied or when the app belongs
    to ``workspace_id``. Any other workspace is rejected, whatever share
    permission it may hold.

    Args:
        app: The application about to be modified.
        workspace_id: The workspace performing the write; None skips the check.

    Raises:
        BusinessException: with ``BizCode.WORKSPACE_NO_ACCESS`` when the app
            belongs to a different workspace.
    """
    # Short-circuits on None so ``app.workspace_id`` is only read when needed.
    belongs_elsewhere = workspace_id is not None and app.workspace_id != workspace_id
    if not belongs_elsewhere:
        return

    rejection_context = {"app_id": str(app.id), "workspace_id": str(workspace_id)}
    logger.warning("应用写操作被拒", extra=rejection_context)
    raise BusinessException("共享应用不可修改配置", BizCode.WORKSPACE_NO_ACCESS)
|
||||
|
||||
def _get_app_or_404(self, app_id: uuid.UUID) -> App:
|
||||
"""获取应用或抛出404异常
|
||||
|
||||
@@ -454,6 +499,33 @@ class AppService:
|
||||
Returns:
|
||||
app_schema.App: 应用 Schema
|
||||
"""
|
||||
is_shared = app.workspace_id != current_workspace_id
|
||||
share_permission = None
|
||||
source_workspace_name = None
|
||||
source_workspace_icon = None
|
||||
source_app_version = None
|
||||
source_app_is_active = None
|
||||
|
||||
if is_shared:
|
||||
# 查询共享权限和来源工作空间名称
|
||||
from app.models import AppShare
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app.id,
|
||||
AppShare.target_workspace_id == current_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
if share:
|
||||
share_permission = share.permission
|
||||
if share.source_workspace:
|
||||
source_workspace_name = share.source_workspace.name
|
||||
source_workspace_icon = share.source_workspace.icon
|
||||
|
||||
# 版本号和生效状态
|
||||
if app.current_release:
|
||||
source_app_version = app.current_release.version_name
|
||||
source_app_is_active = app.is_active
|
||||
|
||||
app_dict = {
|
||||
"id": app.id,
|
||||
"workspace_id": app.workspace_id,
|
||||
@@ -468,7 +540,12 @@ class AppService:
|
||||
"tags": app.tags or [],
|
||||
"current_release_id": app.current_release_id,
|
||||
"is_active": app.is_active,
|
||||
"is_shared": app.workspace_id != current_workspace_id, # 判断是否是共享应用
|
||||
"is_shared": is_shared,
|
||||
"share_permission": share_permission,
|
||||
"source_workspace_name": source_workspace_name,
|
||||
"source_workspace_icon": source_workspace_icon,
|
||||
"source_app_version": source_app_version,
|
||||
"source_app_is_active": source_app_is_active,
|
||||
"created_at": app.created_at,
|
||||
"updated_at": app.updated_at
|
||||
}
|
||||
@@ -594,7 +671,7 @@ class AppService:
|
||||
logger.info("更新应用", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
changed = False
|
||||
for field in ["name", "description", "icon", "icon_type", "visibility", "status", "tags"]:
|
||||
@@ -804,6 +881,7 @@ class AppService:
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
) -> Tuple[List[App], int]:
|
||||
@@ -849,18 +927,24 @@ class AppService:
|
||||
if search:
|
||||
filters.append(func.lower(App.name).like(f"%{search.lower()}%"))
|
||||
|
||||
# 基础查询:本工作空间的应用
|
||||
if include_shared:
|
||||
# 查询本工作空间的应用 + 分享给本工作空间的应用
|
||||
# 使用 OR 条件:workspace_id = current OR app_id IN (shared apps)
|
||||
# shared_only implies include_shared; enforce to avoid confusing API usage
|
||||
if shared_only:
|
||||
include_shared = True
|
||||
|
||||
# 获取分享给本工作空间的应用ID列表
|
||||
# 基础查询:本工作空间的应用
|
||||
if shared_only:
|
||||
# 只返回共享给本工作空间的应用,不含自有应用
|
||||
shared_app_ids_stmt = (
|
||||
select(AppShare.source_app_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
|
||||
)
|
||||
stmt = select(App).where(App.id.in_(shared_app_ids_stmt))
|
||||
elif include_shared:
|
||||
# 查询本工作空间的应用 + 分享给本工作空间的应用
|
||||
shared_app_ids_stmt = (
|
||||
select(AppShare.source_app_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
|
||||
)
|
||||
|
||||
# 构建主查询:本工作空间的应用 OR 分享的应用
|
||||
stmt = select(App).where(
|
||||
or_(
|
||||
App.workspace_id == workspace_id,
|
||||
@@ -952,7 +1036,7 @@ class AppService:
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
|
||||
AgentConfig.updated_at.desc())
|
||||
@@ -1163,7 +1247,7 @@ class AppService:
|
||||
if app.type != AppType.WORKFLOW:
|
||||
raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
# 获取现有配置
|
||||
repo = WorkflowConfigRepository(self.db)
|
||||
@@ -1654,7 +1738,8 @@ class AppService:
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_ids: List[uuid.UUID],
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
permission: str = "readonly"
|
||||
) -> list[AppShare]:
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
@@ -1685,6 +1770,14 @@ class AppService:
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
|
||||
# 仅允许 agent 和 workflow 类型共享,multi_agent 不支持
|
||||
from app.models.app_model import AppType
|
||||
if app.type == AppType.MULTI_AGENT:
|
||||
raise BusinessException(
|
||||
"集群 Agent 不支持共享应用功能",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 2. 验证目标工作空间
|
||||
for target_ws_id in target_workspace_ids:
|
||||
target_ws = self.db.get(Workspace, target_ws_id)
|
||||
@@ -1706,7 +1799,8 @@ class AppService:
|
||||
# 检查是否已经分享过
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == target_ws_id
|
||||
AppShare.target_workspace_id == target_ws_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
existing_share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -1725,6 +1819,7 @@ class AppService:
|
||||
source_workspace_id=app.workspace_id,
|
||||
target_workspace_id=target_ws_id,
|
||||
shared_by=user_id,
|
||||
permission=permission,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
@@ -1784,7 +1879,8 @@ class AppService:
|
||||
# 2. 查找分享记录
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == target_workspace_id
|
||||
AppShare.target_workspace_id == target_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -1798,8 +1894,8 @@ class AppService:
|
||||
f"app_id={app_id}, target_workspace_id={target_workspace_id}"
|
||||
)
|
||||
|
||||
# 3. 删除分享记录
|
||||
self.db.delete(share)
|
||||
# 3. 逻辑删除分享记录
|
||||
share.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
@@ -1807,6 +1903,48 @@ class AppService:
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
|
||||
)
|
||||
|
||||
def unshare_all_apps_to_workspace(
    self,
    *,
    target_workspace_id: uuid.UUID,
    workspace_id: uuid.UUID
) -> int:
    """Deactivate every active share from this workspace to a target workspace.

    Share records are soft-deleted (``is_active`` set to False); no row is
    physically removed.

    Args:
        target_workspace_id: Workspace that currently receives the shares.
        workspace_id: Current (source) workspace that created the shares.

    Returns:
        Number of share records that were active and have been deactivated.
    """
    from app.models import AppShare

    logger.info(
        "取消对目标工作空间的所有应用分享",
        extra={"target_workspace_id": str(target_workspace_id), "workspace_id": str(workspace_id)}
    )

    # Collect the ids first so the returned count reflects exactly the rows
    # targeted by the update below.
    active_ids_query = select(AppShare.id).where(
        AppShare.source_workspace_id == workspace_id,
        AppShare.target_workspace_id == target_workspace_id,
        AppShare.is_active.is_(True)
    )
    share_ids = list(self.db.scalars(active_ids_query).all())

    if share_ids:
        from sqlalchemy import update as sa_update

        deactivate = sa_update(AppShare).where(AppShare.id.in_(share_ids)).values(is_active=False)
        self.db.execute(deactivate)
        self.db.commit()

    revoked = len(share_ids)
    logger.info("已取消分享记录数", extra={"count": revoked})
    return revoked
|
||||
|
||||
def list_app_shares(
|
||||
self,
|
||||
*,
|
||||
@@ -1836,7 +1974,8 @@ class AppService:
|
||||
|
||||
# 查询分享记录
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.is_active.is_(True)
|
||||
).order_by(AppShare.created_at.desc())
|
||||
|
||||
shares = list(self.db.scalars(stmt).all())
|
||||
@@ -1848,6 +1987,166 @@ class AppService:
|
||||
|
||||
return shares
|
||||
|
||||
def remove_shared_app(
    self,
    *,
    app_id: uuid.UUID,
    workspace_id: uuid.UUID
) -> None:
    """Remove a shared app from the recipient's own workspace.

    Only the share record is affected (it is marked inactive); the source
    app itself is not touched.

    Args:
        app_id: ID of the shared application.
        workspace_id: Current workspace ID, i.e. the target workspace the
            app was shared to.

    Raises:
        ResourceNotFoundException: when no active share record exists for
            this app and workspace.
    """
    from app.models import AppShare

    logger.info(
        "移除共享应用",
        extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
    )

    active_share = self.db.scalars(
        select(AppShare).where(
            AppShare.source_app_id == app_id,
            AppShare.target_workspace_id == workspace_id,
            AppShare.is_active.is_(True)
        )
    ).first()

    if active_share:
        # Soft delete: keep the row, just flag it inactive.
        active_share.is_active = False
        self.db.commit()
        logger.info(
            "共享应用已移除",
            extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
        )
        return

    raise ResourceNotFoundException(
        "共享记录",
        f"app_id={app_id}, workspace_id={workspace_id}"
    )
|
||||
|
||||
def remove_all_shared_apps_from_workspace(
    self,
    *,
    source_workspace_id: uuid.UUID,
    workspace_id: uuid.UUID
) -> int:
    """Deactivate all active shares a source workspace made to this workspace.

    Share records are soft-deleted (``is_active`` set to False).

    Args:
        source_workspace_id: The workspace that shared the apps.
        workspace_id: Current workspace ID (the recipient).

    Returns:
        Number of share records that were active and have been deactivated.
    """
    from app.models import AppShare

    logger.info(
        "批量移除来源工作空间的共享应用",
        extra={"source_workspace_id": str(source_workspace_id), "workspace_id": str(workspace_id)}
    )

    # Fetch the matching ids up front so the count matches the rows updated.
    active_ids_query = select(AppShare.id).where(
        AppShare.source_workspace_id == source_workspace_id,
        AppShare.target_workspace_id == workspace_id,
        AppShare.is_active.is_(True)
    )
    share_ids = list(self.db.scalars(active_ids_query).all())

    if share_ids:
        from sqlalchemy import update as sa_update

        deactivate = sa_update(AppShare).where(AppShare.id.in_(share_ids)).values(is_active=False)
        self.db.execute(deactivate)
        self.db.commit()

    removed = len(share_ids)
    logger.info("已移除共享记录数", extra={"count": removed})
    return removed
|
||||
|
||||
def list_my_shared_out(
    self,
    *,
    workspace_id: uuid.UUID
) -> List[AppShare]:
    """List the active share records this workspace has shared out.

    Args:
        workspace_id: The source workspace whose outgoing shares are listed.

    Returns:
        List[AppShare]: Active share records whose source workspace is
        ``workspace_id``, newest first by ``created_at``.
    """
    from app.models import AppShare

    outgoing_query = select(AppShare).where(
        AppShare.source_workspace_id == workspace_id,
        AppShare.is_active.is_(True)
    )
    newest_first = outgoing_query.order_by(AppShare.created_at.desc())
    records = self.db.scalars(newest_first).all()
    return list(records)
|
||||
def update_share_permission(
    self,
    *,
    app_id: uuid.UUID,
    target_workspace_id: uuid.UUID,
    permission: str,
    workspace_id: Optional[uuid.UUID] = None
) -> "AppShare":
    """Change the permission of an active share record (readonly <-> editable).

    Only the workspace that owns the app may change a share's permission.

    Args:
        app_id: ID of the shared application.
        target_workspace_id: Workspace the app is shared to.
        permission: New permission value, either ``readonly`` or ``editable``.
        workspace_id: Current workspace ID, used for the authorization
            checks. When None the workspace checks are skipped by
            ``_validate_app_writable``.

    Returns:
        AppShare: The refreshed share record carrying the new permission.

    Raises:
        BusinessException: when ``permission`` is not an allowed value, or
            when the caller's workspace cannot access / does not own the app.
        ResourceNotFoundException: when no active share record exists for
            this app and target workspace.
    """
    from app.models import AppShare

    if permission not in ("readonly", "editable"):
        raise BusinessException("权限值无效,只允许 readonly 或 editable", BizCode.INVALID_PARAMETER)

    app = self._get_app_or_404(app_id)
    self._validate_workspace_access(app, workspace_id)
    # NOTE(review): _validate_workspace_access also looks up share records, so
    # it appears to admit recipient workspaces. Without an owner-only check a
    # recipient could upgrade its own share from readonly to editable, hence
    # the additional ownership validation here.
    self._validate_app_writable(app, workspace_id)

    stmt = select(AppShare).where(
        AppShare.source_app_id == app_id,
        AppShare.target_workspace_id == target_workspace_id,
        AppShare.is_active.is_(True)
    )
    share = self.db.scalars(stmt).first()

    if not share:
        raise ResourceNotFoundException(
            "共享记录",
            f"app_id={app_id}, target_workspace_id={target_workspace_id}"
        )

    share.permission = permission
    share.updated_at = datetime.datetime.now()
    self.db.commit()
    # Reload so the returned object reflects the committed database state.
    self.db.refresh(share)

    logger.info(
        "共享权限已更新",
        extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id), "permission": permission}
    )
    return share
|
||||
|
||||
|
||||
# ==================== 向后兼容的函数接口 ====================
|
||||
# 保留函数接口以兼容现有代码,但内部使用服务类
|
||||
|
||||
@@ -1942,6 +2241,7 @@ def list_apps(
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
) -> Tuple[List[App], int]:
|
||||
@@ -1954,6 +2254,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
@@ -75,7 +75,7 @@ class AudioTranscriptionService:
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response = await client.get(audio_url, follow_redirects=True)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.service import t
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
logger.warning(f"用户不存在: {email}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
logger.warning(f"用户未激活: {email}")
|
||||
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
logger.warning(f"密码错误: {email}")
|
||||
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
|
||||
|
||||
logger.info(f"用户认证成功: {email}")
|
||||
return user
|
||||
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
|
||||
Raises:
|
||||
BusinessException: token 无效
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
|
||||
return {
|
||||
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
|
||||
"share_token": payload["share_token"]
|
||||
}
|
||||
except jwt.InvalidTokenError:
|
||||
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
|
||||
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)
|
||||
@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -501,9 +502,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -688,7 +698,8 @@ class AgentRunService:
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
sub_agent=sub_agent
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
@@ -703,9 +714,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -840,7 +860,8 @@ class AgentRunService:
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
"is_omni": api_key.is_omni,
|
||||
"capability": api_key.capability
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
@@ -848,7 +869,8 @@ class AgentRunService:
|
||||
conversation_id: Optional[str],
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str]
|
||||
user_id: Optional[str],
|
||||
sub_agent: bool = False
|
||||
) -> str:
|
||||
"""确保会话存在(创建或验证)
|
||||
|
||||
@@ -909,20 +931,36 @@ class AgentRunService:
|
||||
conv_uuid = uuid.UUID(conversation_id)
|
||||
conversation = conversation_service.get_conversation(conv_uuid)
|
||||
|
||||
# 验证会话属于当前工作空间
|
||||
if conversation.workspace_id != workspace_id:
|
||||
logger.warning(
|
||||
"会话不属于当前工作空间",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_workspace_id": str(conversation.workspace_id),
|
||||
"current_workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
raise BusinessException(
|
||||
"会话不属于当前工作空间",
|
||||
BizCode.PERMISSION_DENIED
|
||||
)
|
||||
# 验证会话属于当前工作空间(或属于共享应用的源工作空间)
|
||||
# sub_agent 内部调用时跳过校验,已在上层验证过
|
||||
if not sub_agent and conversation.workspace_id != workspace_id:
|
||||
# 检查是否是共享应用的会话(被共享者 workspace 访问源应用)
|
||||
from app.models import AppShare
|
||||
from sqlalchemy import select as sa_select
|
||||
share = self.db.scalars(
|
||||
sa_select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == workspace_id
|
||||
)
|
||||
).first()
|
||||
|
||||
# 情况2:sub_agent 内部调用时,workspace_id 是源应用的 workspace,
|
||||
# 而会话是被共享者创建的,只要会话属于同一个 app 即可放行
|
||||
same_app = (conversation.app_id == app_id)
|
||||
|
||||
if not share and not same_app:
|
||||
logger.warning(
|
||||
"会话不属于当前工作空间",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_workspace_id": str(conversation.workspace_id),
|
||||
"current_workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
raise BusinessException(
|
||||
"会话不属于当前工作空间",
|
||||
BizCode.PERMISSION_DENIED
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"使用现有会话",
|
||||
|
||||
@@ -274,7 +274,7 @@ class MemoryAgentService:
|
||||
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
PerceptualMemoryItem,
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
|
||||
"keywords": content.keywords,
|
||||
"topic": content.topic,
|
||||
"domain": content.domain,
|
||||
"created_time": int(memory.created_time.timestamp()*1000),
|
||||
"created_time": int(memory.created_time.timestamp() * 1000),
|
||||
**detail
|
||||
}
|
||||
|
||||
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
exc_info=True)
|
||||
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
BizCode.DB_ERROR)
|
||||
|
||||
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
|
||||
for memory in memories:
|
||||
meta_data = memory.meta_data or {}
|
||||
content = meta_data.get("content", {})
|
||||
|
||||
|
||||
# 安全地提取 content 字段,提供默认值
|
||||
if content:
|
||||
content_obj = Content(**content)
|
||||
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
|
||||
topic = "Unknown"
|
||||
domain = "Unknown"
|
||||
keywords = []
|
||||
|
||||
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
|
||||
topic=topic,
|
||||
domain=domain,
|
||||
keywords=keywords,
|
||||
created_time=int(memory.created_time.timestamp()*1000),
|
||||
created_time=int(memory.created_time.timestamp() * 1000),
|
||||
storage_service=FileStorageService(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
async def generate_perceptual_memory(
        self,
        end_user_id: str,
        model_config: ModelInfo,
        file_type: str,
        file_url: str,
        file_message: dict,
) -> None:
    """Create a perceptual-memory record for a file, summarizing it with an LLM.

    If a record for ``file_url`` already exists, no LLM call is made: the
    first existing record is copied for ``end_user_id`` unless that user
    already owns one. Otherwise the file is summarized by the configured
    model and the parsed summary is stored as a new record.

    Args:
        end_user_id: End-user identifier; must be parseable as a UUID.
        model_config: Model connection info used to build the LLM client.
        file_type: File category; compared against ``FileType`` members and
            passed to the system prompt template.
        file_url: URL of the file; used as the lookup key, stored as the
            record's ``file_path``, and parsed for file name / extension.
        file_message: Message content part describing the file, appended to
            the user message sent to the LLM.

    Raises:
        ValueError: when ``end_user_id`` is not a valid UUID string.
        BusinessException: when the system prompt template file is missing,
            or when the LLM output cannot be parsed into a JSON object.
    """
    # Parse once, up front: an invalid ID fails before any LLM call is made,
    # and the same normalized value is used for comparison and for storage.
    user_uuid = uuid.UUID(str(end_user_id))

    memories = self.repository.get_by_url(file_url)
    if memories:
        business_logger.info(f"Perceptual memory already exists: {file_url}")
        # Compare canonical string forms: records are written with a
        # uuid.UUID end_user_id, which never equals the raw string argument,
        # so a direct membership test would copy a duplicate on every call.
        owner_ids = {str(memory.end_user_id) for memory in memories}
        if str(user_uuid) not in owner_ids:
            business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
            memory_cache = memories[0]
            self.repository.create_perceptual_memory(
                end_user_id=user_uuid,
                perceptual_type=PerceptualType(memory_cache.perceptual_type),
                file_path=memory_cache.file_path,
                file_name=memory_cache.file_name,
                file_ext=memory_cache.file_ext,
                summary=memory_cache.summary,
                meta_data=memory_cache.meta_data
            )
            self.db.commit()

        # NOTE(review): the source indentation was lost; the early return is
        # assumed to apply whenever a record for this URL exists — confirm.
        return

    llm = RedBearLLM(RedBearModelConfig(
        model_name=model_config.model_name,
        provider=model_config.provider,
        api_key=model_config.api_key,
        base_url=model_config.api_base,
        is_omni=model_config.is_omni
    ), type=model_config.model_type)
    try:
        prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
        with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
            opt_system_prompt = f.read()
        rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
    except FileNotFoundError as exc:
        raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) from exc

    messages = [
        {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
        {"role": RoleType.USER.value, "content": [
            {"type": "text", "text": "Summarize the following file"}, file_message
        ]}
    ]
    result = await llm.ainvoke(messages)
    content = json_repair.repair_json(result.content, return_objects=True)
    # The repaired output is not guaranteed to be a JSON object (it can be a
    # list, a scalar, or an empty string); fail clearly instead of crashing
    # on ``.get`` further down, and do not persist an empty record.
    if not isinstance(content, dict):
        business_logger.error(
            f"Perceptual summary is not a JSON object: {file_url}, got {type(content).__name__}"
        )
        raise BusinessException(
            message="Perceptual summary output is not a JSON object",
            code=BizCode.VALIDATION_FAILED
        )

    # Derive the file name and extension from the URL path.
    path = urlparse(file_url).path
    filename = unquote(os.path.basename(path))
    file_ext = os.path.splitext(filename)[1]
    if not file_ext:
        # No extension in the URL: fall back to a default per file type.
        # Kept as an ==-chain (not a dict lookup) so a plain string
        # file_type compares against FileType members exactly as before.
        if file_type == FileType.AUDIO:
            file_ext = ".mp3"
        elif file_type == FileType.VIDEO:
            file_ext = ".mp4"
        elif file_type == FileType.DOCUMENT:
            file_ext = ".txt"
        elif file_type == FileType.IMAGE:
            file_ext = ".jpg"
        filename += file_ext

    file_content = {
        "keywords": content.get("keywords", []),
        "topic": content.get("topic"),
        "domain": content.get("domain")
    }
    # Modality-specific fields depend on the file type.
    if file_type in [FileType.IMAGE, FileType.VIDEO]:
        file_modalities = {
            "scene": content.get("scene")
        }
    elif file_type in [FileType.DOCUMENT]:
        file_modalities = {
            "section_count": content.get("section_count"),
            "title": content.get("title"),
            "first_line": content.get("first_line")
        }
    else:
        file_modalities = {
            "speaker_count": content.get("speaker_count")
        }

    self.repository.create_perceptual_memory(
        end_user_id=user_uuid,
        perceptual_type=PerceptualType.trans_from_file_type(file_type),
        file_path=file_url,
        file_name=filename,
        file_ext=file_ext,
        summary=content.get('summary'),
        meta_data={
            "content": file_content,
            "modalities": file_modalities
        }
    )
    self.db.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
@@ -23,9 +24,12 @@ from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import ModelApiKey
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -39,6 +43,7 @@ DOC_MIME = [
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
# 下载图片
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
"""
|
||||
Service for handling multimodal file processing.
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
Attributes:
|
||||
db (Session): Database session.
|
||||
model_api_key (str): API key for the model provider.
|
||||
provider (str): Name of the model provider.
|
||||
is_omni (bool): Indicates whether the model supports full multimodal capability.
|
||||
capability (list): Capability configuration of the model.
|
||||
audio_api_key (str | None): API key used for audio transcription.
|
||||
enable_audio_transcription (bool): Whether audio transcription is enabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_config: ModelInfo | None = None,
|
||||
audio_api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Initialize the multimodal service.
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
db (Session): Database session.
|
||||
api_config (ModelApiKey | None): Model API configuration.
|
||||
audio_api_key (str | None): API key for audio transcription.
|
||||
enable_audio_transcription (bool): Enable audio transcription.
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.api_config = api_config
|
||||
if self.api_config is not None:
|
||||
self.model_api_key = api_config.api_key
|
||||
self.provider = api_config.provider.lower()
|
||||
self.is_omni = api_config.is_omni
|
||||
self.capability = api_config.capability
|
||||
self.audio_api_key = audio_api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -319,6 +346,8 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -333,19 +362,25 @@ class MultimodalService:
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
if not file.url:
|
||||
file.url = await self.get_file_url(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -355,7 +390,8 @@ class MultimodalService:
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
@@ -366,6 +402,17 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
        self,
        end_user_id: str,
        file_type: str,
        file_url: str,
        file_message: dict
):
    """
    Queue a perceptual-memory write for one processed file.

    Nothing is queued when ``end_user_id`` is falsy or when the service has no
    ``api_config``.

    Args:
        end_user_id: Identifier of the end user the file belongs to.
        file_type: Type of the processed file.
        file_url: URL of the processed file.
        file_message: Message content that was produced for the file.
    """
    if not (end_user_id and self.api_config):
        return

    # The name below resolves to the module-level task imported from
    # app.tasks, not to this method.
    model_config = self.api_config.model_dump()
    write_perceptual_memory.delay(end_user_id, model_config, file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -387,43 +434,6 @@ class MultimodalService:
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
    """
    Download an image and return it base64-encoded together with its media type.

    The media type is taken from the response ``Content-Type`` header when it
    names an ``image/*`` type (header parameters such as ``; charset=...`` are
    dropped). Otherwise it is guessed from the URL, falling back to
    ``image/jpeg``.

    Args:
        url: Image URL.

    Returns:
        tuple: (base64_data, media_type)

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    from mimetypes import guess_type

    # Follow redirects, like the other download sites in this module.
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.get(url, follow_redirects=True)
        response.raise_for_status()

    image_data = response.content

    # Keep only the bare MIME type: "image/png; charset=binary" -> "image/png".
    header_value = response.headers.get("content-type") or ""
    content_type = header_value.split(";", 1)[0].strip().lower()
    if content_type.startswith("image/"):
        media_type = content_type
    else:
        # Fall back to guessing from the URL.
        guessed_type, _ = guess_type(url)
        media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"

    base64_data = base64.b64encode(image_data).decode("utf-8")

    logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")

    return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -436,7 +446,6 @@ class MultimodalService:
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
@@ -471,12 +480,12 @@ class MultimodalService:
|
||||
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
if self.enable_audio_transcription and self.audio_api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
@@ -557,7 +566,7 @@ class MultimodalService:
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response = await client.get(file.url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
|
||||
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
@@ -0,0 +1,53 @@
|
||||
{% raw %}You are a professional information extraction system.
|
||||
|
||||
Your task is to analyze the provided document content and generate structured metadata.
|
||||
|
||||
Extract the following fields:
|
||||
|
||||
* **summary**: A concise summary of the document in 2–4 sentences.
|
||||
* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
|
||||
* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
|
||||
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
|
||||
|
||||
STRICT RULES:
|
||||
|
||||
1. Output MUST be valid JSON.
|
||||
2. Do NOT output markdown.
|
||||
3. Do NOT output explanations.
|
||||
4. Do NOT output any text before or after the JSON.
|
||||
5. The JSON MUST contain EXACTLY these keys:
|
||||
* summary
|
||||
* keywords
|
||||
* topic
|
||||
* domain{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
|
||||
{% if file_type == 'audio' %} * speaker_count {% endif %}
|
||||
{% if file_type == 'document' %} * section_count
|
||||
* title
|
||||
* first_line
|
||||
{% endif %}
|
||||
{% raw %}
|
||||
6. `keywords` MUST be a JSON array of strings.
|
||||
7. If the document content is insufficient, infer the best possible answer based on context.
|
||||
8. Ensure the JSON is syntactically correct.
|
||||
{% endraw %}
|
||||
9. Output using the language {{ language }}
|
||||
{% raw %}
|
||||
Required JSON format:
|
||||
|
||||
{
|
||||
"summary": "string",
|
||||
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
|
||||
"topic": "string",
|
||||
"domain": "string",
|
||||
{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
|
||||
{% if file_type == 'document' %} "section_count": integer,
|
||||
"title": "string",
|
||||
"first_line": "string"
|
||||
{% endif %}
|
||||
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
|
||||
{% raw %}
|
||||
}
|
||||
|
||||
Now analyze the following document and return the JSON result.{% endraw %}
|
||||
@@ -217,4 +217,55 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active
|
||||
)
|
||||
)
|
||||
|
||||
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
    """
    Return the language configuration of a tenant.

    Args:
        tenant_id: ID of the tenant to look up.

    Returns:
        A dict holding the tenant's ``default_language`` and
        ``supported_languages``.

    Raises:
        BusinessException: With ``BizCode.TENANT_NOT_FOUND`` when the
            repository returns no tenant for ``tenant_id``.
    """
    tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
    if tenant:
        language_config = {
            "default_language": tenant.default_language,
            "supported_languages": tenant.supported_languages
        }
        return language_config

    raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
def update_tenant_language_config(
    self,
    tenant_id: uuid.UUID,
    default_language: str,
    supported_languages: list
) -> Optional[dict]:
    """
    Update the language configuration of a tenant.

    Args:
        tenant_id: ID of the tenant to update.
        default_language: New default language; must be one of
            ``supported_languages``.
        supported_languages: New list of supported languages.

    Returns:
        A dict holding the tenant's ``default_language`` and
        ``supported_languages`` after the update.

    Raises:
        BusinessException: ``TENANT_NOT_FOUND`` when the tenant does not exist,
            ``VALIDATION_FAILED`` when the default language is not in the
            supported list, ``DB_ERROR`` when persisting the change fails.
    """
    # Make sure the tenant exists.
    tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
    if not tenant:
        raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)

    # The default language has to be one of the supported languages.
    if default_language not in supported_languages:
        raise BusinessException(
            "默认语言必须在支持的语言列表中",
            code=BizCode.VALIDATION_FAILED
        )

    try:
        tenant.default_language = default_language
        tenant.supported_languages = supported_languages
        self.db.commit()
        self.db.refresh(tenant)

        business_logger.info(
            f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
            f"默认语言: {default_language}, 支持语言: {supported_languages}"
        )

        return {
            "default_language": tenant.default_language,
            "supported_languages": tenant.supported_languages
        }
    except Exception as e:
        self.db.rollback()
        # Record the traceback and chain the cause so the root failure is
        # not lost behind the business-level error.
        business_logger.error(f"更新租户语言配置失败: {str(e)}", exc_info=True)
        raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR) from e
|
||||
|
||||
@@ -1727,6 +1727,150 @@ async def analytics_graph_data(
|
||||
|
||||
# 辅助函数
|
||||
|
||||
async def analytics_community_graph_data(
        db: Session,
        end_user_id: str,
) -> Dict[str, Any]:
    """
    Build the community graph for an end user.

    The graph contains Community nodes, ExtractedEntity nodes, the
    community-membership edges between them and the relationships between
    entities.

    Args:
        db: Database session used to look up the end user.
        end_user_id: End user ID as a UUID string.

    Returns:
        A dict with ``nodes``, ``edges`` and ``statistics``. When the ID is not
        a valid UUID or the user does not exist, the lists are empty and a
        ``message`` entry explains why.

    Raises:
        Exception: Any error raised while querying or assembling the graph is
            logged and re-raised.
    """
    community_keys = (
        "community_id", "end_user_id", "member_count", "updated_at",
        "name", "summary", "core_entities",
    )
    entity_keys = ("name", "end_user_id", "description", "created_at", "entity_type")

    def _empty_result(message: str) -> Dict[str, Any]:
        # Empty graph carrying an explanatory message.
        return {
            "nodes": [], "edges": [],
            "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
            "message": message
        }

    def _pick_props(raw: dict, keys: tuple) -> dict:
        # Copy the whitelisted keys that are present, cleaning each value.
        return {k: _clean_neo4j_value(raw.get(k)) for k in keys if k in raw}

    try:
        # Only a malformed ID is reported as "invalid user ID"; a ValueError
        # raised later goes through the generic handler below instead of
        # being misreported.
        try:
            user_uuid = uuid.UUID(end_user_id)
        except ValueError:
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return _empty_result("无效的用户ID格式")

        repo = EndUserRepository(db)
        end_user = repo.get_by_id(user_uuid)
        if not end_user:
            return _empty_result("用户不存在")

        # Query community nodes, entity nodes, membership edges and
        # entity-to-entity relationships.
        from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA
        rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id)

        nodes_map: Dict[str, dict] = {}
        edges_map: Dict[str, dict] = {}
        # Community id -> member entity ids. A dict is used as an
        # insertion-ordered set so de-duplication stays O(1) per row.
        community_members: Dict[str, Dict[str, None]] = {}

        for row in rows:
            # Community node
            c_id = row["c_id"]
            if c_id and c_id not in nodes_map:
                nodes_map[c_id] = {
                    "id": c_id,
                    "label": "Community",
                    "properties": _pick_props(row["c_props"] or {}, community_keys),
                }

            # ExtractedEntity node (e)
            e_id = row["e_id"]
            if e_id and e_id not in nodes_map:
                props = _pick_props(row["e_props"] or {}, entity_keys)
                # c is the community that e belongs to directly.
                c_raw = row["c_props"] or {}
                props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or ""
                nodes_map[e_id] = {
                    "id": e_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }

            # ExtractedEntity node (e2, optional)
            e2_id = row.get("e2_id")
            if e2_id and e2_id not in nodes_map:
                props = _pick_props(row["e2_props"] or {}, entity_keys)
                # The community of e2 is filled in during post-processing.
                props["community_name"] = ""
                nodes_map[e2_id] = {
                    "id": e2_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }

            # Membership edge (entity -> community)
            b_id = row["b_id"]
            if b_id and b_id not in edges_map:
                edges_map[b_id] = {
                    "id": b_id,
                    "source": e_id,
                    "target": c_id,
                }

            # Collect community member ids.
            if c_id and e_id:
                community_members.setdefault(c_id, {})[e_id] = None

            # Entity-to-entity relationship edge (optional)
            r_id = row.get("r_id")
            if r_id and r_id not in edges_map and e2_id:
                # r_from_e tells whether the relationship starts at e.
                from_e = row.get("r_from_e")
                edges_map[r_id] = {
                    "id": r_id,
                    "source": e_id if from_e else e2_id,
                    "target": e2_id if from_e else e_id,
                }

        nodes = list(nodes_map.values())
        edges = list(edges_map.values())

        # Attach member_entity_ids to every Community node and fill in the
        # community name of member entities that still have none.
        for c_id, members in community_members.items():
            c_node = nodes_map.get(c_id)
            if not c_node:
                continue
            member_ids = list(members)
            c_node["properties"]["member_entity_ids"] = member_ids
            c_name = c_node["properties"].get("name") or ""
            for eid in member_ids:
                e_node = nodes_map.get(eid)
                if e_node and e_node["label"] == "ExtractedEntity":
                    if not e_node["properties"].get("community_name"):
                        e_node["properties"]["community_name"] = c_name

        node_type_counts: Dict[str, int] = {}
        for n in nodes:
            node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1

        # NOTE(review): the empty results carry an "edge_types" entry in
        # statistics but this success result does not — confirm which shape
        # consumers expect.
        return {
            "nodes": nodes,
            "edges": edges,
            "statistics": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "node_types": node_type_counts,
            }
        }

    except Exception as e:
        logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True)
        raise
|
||||
|
||||
|
||||
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据节点类型提取需要的属性字段
|
||||
|
||||
@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
|
||||
|
||||
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
|
||||
"""普通用户修改自己的密码"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
|
||||
|
||||
# 检查权限:只能修改自己的密码
|
||||
if current_user.id != user_id:
|
||||
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
|
||||
raise PermissionDeniedException("You can only change your own password")
|
||||
raise PermissionDeniedException(t("auth.password.change_failed"))
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, db_user.hashed_password):
|
||||
business_logger.warning(f"用户旧密码验证失败: {user_id}")
|
||||
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 更新密码
|
||||
db_user.hashed_password = get_password_hash(new_password)
|
||||
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
|
||||
except Exception as e:
|
||||
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
|
||||
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
Returns:
|
||||
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
|
||||
|
||||
# 检查权限:只有超级管理员可以修改他人密码
|
||||
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员可以修改他人密码"
|
||||
error_message=t("auth.password.change_failed")
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
|
||||
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
|
||||
if not target_user:
|
||||
business_logger.warning(f"目标用户不存在: {target_user_id}")
|
||||
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查租户权限:超管只能修改同租户用户的密码
|
||||
if current_user.tenant_id != target_user.tenant_id:
|
||||
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
|
||||
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 如果没有提供新密码,则生成随机密码
|
||||
actual_password = new_password if new_password else generate_random_password()
|
||||
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
except Exception as e:
|
||||
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def generate_random_password(length: int = 12) -> str:
|
||||
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
|
||||
#
|
||||
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
# return db_user
|
||||
|
||||
|
||||
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
    """
    Return the preferred language of a user.

    Args:
        db: Database session.
        user_id: ID of the user whose preference is requested.
        current_user: The authenticated user making the request.

    Returns:
        The stored ``preferred_language``, or ``"zh"`` when none is stored.

    Raises:
        PermissionDeniedException: When ``current_user`` is not the user
            identified by ``user_id``.
        BusinessException: With ``BizCode.USER_NOT_FOUND`` when the user does
            not exist.
    """
    business_logger.info(f"获取用户语言偏好: user_id={user_id}")

    # Users may only read their own preference.
    if current_user.id != user_id:
        raise PermissionDeniedException("只能获取自己的语言偏好")

    account = user_repository.get_user_by_id(db=db, user_id=user_id)
    if not account:
        raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)

    preferred = account.preferred_language or "zh"
    business_logger.info(f"用户语言偏好: {account.username}, language={preferred}")
    return preferred
|
||||
|
||||
|
||||
def update_user_language_preference(
    db: Session,
    user_id: uuid.UUID,
    language: str,
    current_user: User
) -> User:
    """
    Store a new preferred language for a user.

    Args:
        db: Database session.
        user_id: ID of the user to update.
        language: Language code to store; must be listed in
            ``settings.I18N_SUPPORTED_LANGUAGES``.
        current_user: The authenticated user making the request.

    Returns:
        The refreshed user object.

    Raises:
        PermissionDeniedException: When ``current_user`` is not the user
            identified by ``user_id``.
        BusinessException: ``VALIDATION_FAILED`` for an unsupported language
            code, ``USER_NOT_FOUND`` when the user does not exist.
    """
    business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")

    # Users may only change their own preference.
    if current_user.id != user_id:
        raise PermissionDeniedException("只能修改自己的语言偏好")

    # Reject language codes the application does not support.
    from app.core.config import settings
    if language not in settings.I18N_SUPPORTED_LANGUAGES:
        message = f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}"
        raise BusinessException(message, code=BizCode.VALIDATION_FAILED)

    account = user_repository.get_user_by_id(db=db, user_id=user_id)
    if not account:
        raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)

    # Persist the new preference.
    account.preferred_language = language
    db.commit()
    db.refresh(account)

    business_logger.info(f"用户语言偏好更新成功: {account.username}, language={language}")
    return account
|
||||
|
||||
523
api/app/tasks.py
523
api/app/tasks.py
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -11,20 +10,48 @@ from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import Document, File, Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
from app.utils.redis_lock import RedisLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 模块级同步 Redis 连接池,供 Celery 任务共享使用
|
||||
# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致
|
||||
# 使用连接池而非单例客户端,提供更好的并发性能和自动重连
|
||||
_sync_redis_pool: redis.ConnectionPool = None
|
||||
_sync_redis_pool: redis.ConnectionPool | None = None
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
|
||||
def _get_or_create_redis_pool() -> redis.ConnectionPool | None:
|
||||
"""获取或创建 Redis 连接池(懒初始化)"""
|
||||
global _sync_redis_pool
|
||||
if _sync_redis_pool is None:
|
||||
@@ -47,6 +74,7 @@ def _get_or_create_redis_pool() -> redis.ConnectionPool:
|
||||
return None
|
||||
return _sync_redis_pool
|
||||
|
||||
|
||||
def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
"""获取同步 Redis 客户端(使用连接池)
|
||||
|
||||
@@ -60,7 +88,7 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
pool = _get_or_create_redis_pool()
|
||||
if pool is None:
|
||||
return None
|
||||
|
||||
|
||||
client = redis.StrictRedis(connection_pool=pool)
|
||||
# 验证连接可用性
|
||||
client.ping()
|
||||
@@ -72,32 +100,18 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
from app.core.rag.graphrag.general.index import init_graphrag, run_graphrag_for_kb
|
||||
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.integrations.yuque.models import YuqueDocInfo
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
def set_asyncio_event_loop():
    """
    Return an open asyncio event loop for the current thread.

    The current loop is reused when one is available and still open.
    Otherwise a new loop is created, installed as the current loop and
    returned.
    """
    try:
        current = asyncio.get_event_loop()
    except RuntimeError:
        # No loop is available for this thread.
        current = None

    if current is None or current.is_closed():
        current = asyncio.new_event_loop()
        asyncio.set_event_loop(current)
    return current
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
@@ -294,9 +308,18 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
vector_size = len(vts[0])
|
||||
init_graphrag(task, vector_size)
|
||||
|
||||
async def _run(row: dict, document_ids: list[str], language: str, parser_config: dict, vector_service,
|
||||
chat_model, embedding_model, callback, with_resolution: bool = True,
|
||||
with_community: bool = True, ) -> dict:
|
||||
async def _run(
|
||||
row: dict,
|
||||
document_ids: list[str],
|
||||
language: str,
|
||||
parser_config: dict,
|
||||
vector_service,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
with_resolution: bool = True,
|
||||
with_community: bool = True
|
||||
) -> dict:
|
||||
await trio.sleep(5) # Delay for 10 seconds
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
result = await run_graphrag_for_kb(
|
||||
@@ -329,6 +352,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -448,6 +472,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
@@ -1002,29 +1027,21 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
async def _run() -> dict:
|
||||
with get_db_context() as db:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||
storage_type, user_rag_memory_id)
|
||||
return await service.read_memory(
|
||||
end_user_id,
|
||||
message,
|
||||
history,
|
||||
search_switch,
|
||||
actual_config_id, db,
|
||||
storage_type, user_rag_memory_id
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1056,7 +1073,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str,
|
||||
def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
language: str = "zh") -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
Args:
|
||||
@@ -1073,10 +1091,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
Raises:
|
||||
Exception on failure
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}")
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, "
|
||||
f"config_id={config_id} (type: {type(config_id).__name__}), "
|
||||
f"storage_type={storage_type}, language={language}")
|
||||
start_time = time.time()
|
||||
|
||||
# Convert config_id to UUID
|
||||
@@ -1086,13 +1105,14 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(config_id, db)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
print(actual_config_id)
|
||||
print(100*'-')
|
||||
print(100 * '-')
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
logger.error(
|
||||
f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": f"Invalid config_id format: {config_id} - {str(e)}",
|
||||
@@ -1116,7 +1136,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
async def _run() -> str:
|
||||
with get_db_context() as db:
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory "
|
||||
f"with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||
user_rag_memory_id, language)
|
||||
@@ -1124,22 +1145,8 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
return result
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1193,28 +1200,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
}
|
||||
|
||||
|
||||
def reflection_engine(host_id: Optional[uuid.UUID] = None) -> None:
    """Run one self-reflection pass for a host.

    This is not a no-op placeholder: it imports ``self_reflexion`` and runs
    it to completion with ``asyncio.run``.

    Args:
        host_id: Host to run the reflection for. When omitted, the
            previously hard-coded host id is used, so existing callers
            that pass no argument behave exactly as before.

    Raises:
        Any exception raised by the import, by ``self_reflexion`` or by
        ``asyncio.run`` propagates to the caller.
    """
    import asyncio

    from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion

    if host_id is None:
        # Backward-compatible default: the host id this function always used.
        host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
    asyncio.run(self_reflexion(host_id))
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
def reflection_timer_task() -> None:
    """Celery entry point that runs ``reflection_engine`` once.

    The task adds no error handling of its own: whatever
    ``reflection_engine`` raises propagates to Celery.
    """
    reflection_engine()
|
||||
|
||||
|
||||
# unused task
|
||||
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||
# def check_read_service_task() -> Dict[str, str]:
|
||||
@@ -1368,6 +1353,8 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
@@ -1391,15 +1378,12 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
@@ -1408,7 +1392,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
@@ -1416,13 +1400,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
@@ -1447,7 +1431,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
@@ -1472,7 +1456,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
@@ -1496,13 +1480,13 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
@@ -1525,7 +1509,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1534,22 +1518,8 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1597,11 +1567,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||
|
||||
service = UserMemoryService()
|
||||
@@ -1734,22 +1702,8 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1785,15 +1739,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.services.memory_reflection_service import (
|
||||
MemoryReflectionService,
|
||||
WorkspaceAppService,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
@@ -1812,7 +1763,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
@@ -1824,7 +1775,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
if not data['memory_configs']:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
@@ -1835,7 +1786,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
|
||||
user['app_id']):
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
@@ -1855,12 +1806,12 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
@@ -1879,7 +1830,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
@@ -1888,22 +1839,8 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -1944,18 +1881,16 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
# 运行遗忘周期
|
||||
# FIXME: MemeoryForgetService
|
||||
report = await forget_service.trigger_forgetting(
|
||||
db=db,
|
||||
end_user_id=None, # 处理所有组
|
||||
@@ -1964,7 +1899,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
api_logger.info(
|
||||
logger.info(
|
||||
f"遗忘周期定时任务完成: "
|
||||
f"融合 {report['merged_count']} 对节点, "
|
||||
f"失败 {report['failed_count']} 对, "
|
||||
@@ -1980,7 +1915,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILED",
|
||||
@@ -1997,6 +1932,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
@@ -2222,9 +2158,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
@@ -2233,7 +2168,6 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
|
||||
total_users = 0
|
||||
@@ -2267,7 +2201,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
for end_user_id in refresh_iter:
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
@@ -2318,7 +2252,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
|
||||
|
||||
# 记录用户处理结果
|
||||
user_result = {
|
||||
"end_user_id": end_user_id,
|
||||
@@ -2460,22 +2394,8 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -2521,14 +2441,12 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2587,20 +2505,7 @@ def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2633,6 +2538,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
self: task object
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
@@ -2641,11 +2547,9 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
@@ -2694,20 +2598,7 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
@@ -2720,3 +2611,185 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
    name="app.tasks.write_perceptual_memory",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
)
def write_perceptual_memory(
        self,
        end_user_id: str,
        model_api_config: dict,
        file_type: str,
        file_url: str,
        file_message: dict
):
    """Generate perceptual memory for one file of an end user.

    The work runs while holding a ``RedisLock`` whose key is derived from
    the MD5 of ``file_url``. Inside the lock a ``ModelInfo`` is built from
    ``model_api_config`` and
    ``MemoryPerceptualService.generate_perceptual_memory`` is awaited on the
    thread's event loop. (The original note states the memory is written to
    PostgreSQL and Neo4j; that happens inside the service, not here.)

    Args:
        self: Bound Celery task instance (not used by the body).
        end_user_id: Identifier of the end user, passed through unchanged.
        model_api_config: Keyword arguments used to construct ``ModelInfo``.
        file_type: The file type, passed through to the service.
        file_url: URL of the file; also the source of the lock key.
        file_message: Details about the file to be processed.

    Returns:
        Whatever ``generate_perceptual_memory`` returns. The task is
        registered with ``ignore_result=True``.
    """
    # MD5 only yields a compact, fixed-length lock key; it is not used for security.
    file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
    # Run on the loop returned by the helper, as the other tasks in this
    # module do. Previously the helper's result was discarded and
    # asyncio.run() was used, which normally starts a separate loop.
    loop = set_asyncio_event_loop()
    with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
        model_info = ModelInfo(**model_api_config)
        with get_db_context() as db:
            memory_perceptual_service = MemoryPerceptualService(db)
            return loop.run_until_complete(memory_perceptual_service.generate_perceptual_memory(
                end_user_id,
                model_info,
                file_type,
                file_url,
                file_message,
            ))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 社区聚类补全任务(触发型)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
    name="app.tasks.init_community_clustering_for_users",
    bind=True,
    ignore_result=False,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # 2-hour hard time limit
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Triggered task: run full community clustering for users that lack it.

    For each user in ``end_user_ids``, full clustering is executed when the
    user has ExtractedEntity nodes but no Community nodes. Users that
    already have community nodes, or have no entities, are skipped.
    According to the original note the task is triggered by the
    ``/dashboard/end_users`` endpoint.

    Args:
        self: Bound Celery task instance; ``self.request.id`` is reported.
        end_user_ids: IDs of the users to check.

    Returns:
        A dict with ``status``, ``elapsed_time`` and ``task_id``. On success
        it also carries the ``initialized`` / ``skipped`` / ``failed``
        counters; on a task-level failure it carries ``error`` instead.
        A per-user failure only increments ``failed``; the status stays
        ``"SUCCESS"``.
    """
    from app.core.logging_config import get_logger

    # Created at function scope so the task-level error handler can log too.
    logger = get_logger(__name__)
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.repositories.neo4j.community_repository import CommunityRepository
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine

        logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")

        initialized = 0
        skipped = 0
        failed = 0

        connector = Neo4jConnector()
        try:
            repo = CommunityRepository(connector)

            # Prefetch every user's configuration in one batch. Per the original
            # note, the batch helper falls back to the workspace default
            # configuration when a user's own configuration is unavailable.
            user_llm_map: Dict[str, Optional[str]] = {}
            try:
                with get_db_context() as db:
                    from app.services.memory_agent_service import get_end_users_connected_configs_batch
                    from app.services.memory_config_service import MemoryConfigService
                    batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
                    for uid, cfg_info in batch_configs.items():
                        config_id = cfg_info.get("memory_config_id")
                        if config_id:
                            try:
                                cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
                                user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
                            except Exception as e:
                                logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
                                user_llm_map[uid] = None
                        else:
                            user_llm_map[uid] = None
            except Exception as e:
                # Best effort: without the map every user clusters with llm_model_id=None.
                logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")

            for end_user_id in end_user_ids:
                try:
                    # Skip users that already have community nodes.
                    has_communities = await repo.has_communities(end_user_id)
                    if has_communities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                        continue

                    # Skip users without any ExtractedEntity nodes.
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                        continue

                    # Each user clusters with their own llm_model_id.
                    llm_model_id = user_llm_map.get(end_user_id)
                    engine = LabelPropagationEngine(
                        connector=connector,
                        llm_model_id=llm_model_id,
                    )

                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")

                except Exception as e:
                    # One user's failure must not stop the rest; keep the traceback.
                    failed += 1
                    logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}", exc_info=True)

        finally:
            await connector.close()

        logger.info(
            f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
        )
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        try:
            # Optional dependency: apply it when installed, carry on otherwise.
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result

    except Exception as e:
        # Previously this path returned FAILURE without logging anything.
        logger.error(f"[CommunityCluster] 任务执行失败: {e}", exc_info=True)
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user