Merge branch 'develop' of https://github.com/SuanmoSuanyangTechnology/MemoryBear into feature/app-share-wxy

This commit is contained in:
wxy
2026-03-13 17:24:20 +08:00
100 changed files with 8956 additions and 1123 deletions

View File

@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.draft_run_service import AgentRunService
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -126,8 +122,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态
@@ -266,8 +271,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态

View File

@@ -75,7 +75,7 @@ class AudioTranscriptionService:
try:
# 下载音频文件
async with httpx.AsyncClient(timeout=60.0) as client:
audio_response = await client.get(audio_url)
audio_response = await client.get(audio_url, follow_redirects=True)
audio_response.raise_for_status()
audio_data = audio_response.content

View File

@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_auth_logger
from app.i18n.service import t
logger = get_auth_logger()
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
user = user_repository.get_user_by_email(db, email=email)
if not user:
logger.warning(f"用户不存在: {email}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查用户状态
if not user.is_active:
logger.warning(f"用户未激活: {email}")
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
# 验证密码
if not verify_password(password, user.hashed_password):
logger.warning(f"密码错误: {email}")
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
logger.info(f"用户认证成功: {email}")
return user
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
Raises:
BusinessException: token 无效
"""
from app.i18n.service import t
try:
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
return {
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
"share_token": payload["share_token"]
}
except jwt.InvalidTokenError:
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)

View File

@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
@@ -501,9 +502,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -704,9 +714,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -841,7 +860,8 @@ class AgentRunService:
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
"is_omni": api_key.is_omni,
"capability": api_key.capability
}
async def _ensure_conversation(

View File

@@ -274,7 +274,7 @@ class MemoryAgentService:
Args:
end_user_id: Group identifier (also used as end_user_id)
message: Message to write
messages: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)

View File

@@ -1,19 +1,27 @@
import os
import uuid
from typing import Dict, Any, Optional
from urllib.parse import urlparse, unquote
import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal
)
from app.schemas.model_schema import ModelInfo
business_logger = get_business_logger()
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
"keywords": content.keywords,
"topic": content.topic,
"domain": content.domain,
"created_time": int(memory.created_time.timestamp()*1000),
"created_time": int(memory.created_time.timestamp() * 1000),
**detail
}
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
return result
except Exception as e:
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
exc_info=True)
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR)
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
for memory in memories:
meta_data = memory.meta_data or {}
content = meta_data.get("content", {})
# 安全地提取 content 字段,提供默认值
if content:
content_obj = Content(**content)
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
topic = "Unknown"
domain = "Unknown"
keywords = []
memory_item = PerceptualMemoryItem(
id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type),
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
topic=topic,
domain=domain,
keywords=keywords,
created_time=int(memory.created_time.timestamp()*1000),
created_time=int(memory.created_time.timestamp() * 1000),
storage_service=FileStorageService(memory.storage_service),
)
memory_items.append(memory_item)
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
async def generate_perceptual_memory(
self,
end_user_id: str,
model_config: ModelInfo,
file_type: str,
file_url: str,
file_message: dict,
):
memories = self.repository.get_by_url(file_url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return
llm = RedBearLLM(RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
), type=model_config.model_type)
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True)
path = urlparse(file_url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
if not file_ext:
if file_type == FileType.AUDIO:
file_ext = ".mp3"
elif file_type == FileType.VIDEO:
file_ext = ".mp4"
elif file_type == FileType.DOCUMENT:
file_ext = ".txt"
elif file_type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
"keywords": content.get("keywords", []),
"topic": content.get("topic"),
"domain": content.get("domain")
}
if file_type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene")
}
elif file_type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count"),
"title": content.get("title"),
"first_line": content.get("first_line")
}
else:
file_modalities = {
"speaker_count": content.get("speaker_count")
}
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type),
file_path=file_url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary'),
meta_data={
"content": file_content,
"modalities": file_modalities
}
)
self.db.commit()

View File

@@ -10,6 +10,7 @@
"""
import base64
import io
import uuid
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -23,9 +24,12 @@ from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models import ModelApiKey
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -39,6 +43,7 @@ DOC_MIME = [
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
def __init__(self, file: FileInput):
self.file = file
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
# 下载图片
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
content = response.content
self.file.set_content(content)
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
audio_data = content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
audio_data = response.content
self.file.set_content(audio_data)
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
"""
Service for handling multimodal file processing.
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
enable_audio_transcription: bool = False, is_omni: bool = False):
Attributes:
db (Session): Database session.
model_api_key (str): API key for the model provider.
provider (str): Name of the model provider.
is_omni (bool): Indicates whether the model supports full multimodal capability.
capability (list): Capability configuration of the model.
audio_api_key (str | None): API key used for audio transcription.
enable_audio_transcription (bool): Whether audio transcription is enabled.
"""
def __init__(
self,
db: Session,
api_config: ModelInfo | None = None,
audio_api_key: Optional[str] = None,
enable_audio_transcription: bool = False,
):
"""
初始化多模态服务
Initialize the multimodal service.
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
db (Session): Database session.
api_config (ModelApiKey | None): Model API configuration.
audio_api_key (str | None): API key for audio transcription.
enable_audio_transcription (bool): Enable audio transcription.
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.api_config = api_config
if self.api_config is not None:
self.model_api_key = api_config.api_key
self.provider = api_config.provider.lower()
self.is_omni = api_config.is_omni
self.capability = api_config.capability
self.audio_api_key = audio_api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
files: Optional[List[FileInput]]
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
end_user_id: 用户ID
files: 文件输入列表
Returns:
@@ -319,6 +346,8 @@ class MultimodalService:
"""
if not files:
return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -333,19 +362,25 @@ class MultimodalService:
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try:
if file.type == FileType.IMAGE:
if file.type == FileType.IMAGE and "vision" in self.capability:
content = await self._process_image(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability:
content = await self._process_video(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
@@ -355,7 +390,8 @@ class MultimodalService:
"file_index": idx,
"file_type": file.type,
"error": str(e)
}
},
exc_info=True
)
# 继续处理其他文件,不中断整个流程
result.append({
@@ -366,6 +402,17 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
@@ -387,43 +434,6 @@ class MultimodalService:
"text": f"[图片处理失败: {str(e)}]"
}
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
Args:
url: 图片 URL
Returns:
tuple: (base64_data, media_type)
"""
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
@@ -436,7 +446,6 @@ class MultimodalService:
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
return {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
@@ -471,12 +480,12 @@ class MultimodalService:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
if self.enable_audio_transcription and self.audio_api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
@@ -557,7 +566,7 @@ class MultimodalService:
file_content = file.get_content()
if not file_content:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file.url)
response = await client.get(file.url, follow_redirects=True)
response.raise_for_status()
file_content = response.content
file.set_content(file_content)

View File

@@ -0,0 +1,53 @@
{% raw %}You are a professional information extraction system.
Your task is to analyze the provided document content and generate structured metadata.
Extract the following fields:
* **summary**: A concise summary of the document in 24 sentences.
* **keywords**: 510 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the document expressed as a short phrase (38 words).
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
STRICT RULES:
1. Output MUST be valid JSON.
2. Do NOT output markdown.
3. Do NOT output explanations.
4. Do NOT output any text before or after the JSON.
5. The JSON MUST contain EXACTLY these four keys:
* summary
* keywords
* topic
* domain{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
{% if file_type == 'audio' %} * speaker_count {% endif %}
{% if file_type == 'document' %} * section_count
* title
* first_line
{% endif %}
{% raw %}
6. `keywords` MUST be a JSON array of strings.
7. If the document content is insufficient, infer the best possible answer based on context.
8. Ensure the JSON is syntactically correct.
{% endraw %}
9. Output using the language {{ language }}
{% raw %}
Required JSON format:
{
"summary": "string",
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
"topic": "string",
"domain": "string",
{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
{% if file_type == 'document' %} "section_count": integer
"title": "string",
"first_line": "string"
{% endif %}
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
{% raw %}
}
Now analyze the following document and return the JSON result.{% endraw %}

View File

@@ -217,4 +217,55 @@ class TenantService:
skip=skip,
limit=limit,
is_active=is_active
)
)
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
"""获取租户语言配置"""
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
def update_tenant_language_config(
self,
tenant_id: uuid.UUID,
default_language: str,
supported_languages: list
) -> Optional[dict]:
"""更新租户语言配置"""
# 检查租户是否存在
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
# 验证默认语言在支持的语言列表中
if default_language not in supported_languages:
raise BusinessException(
"默认语言必须在支持的语言列表中",
code=BizCode.VALIDATION_FAILED
)
try:
# 更新语言配置
tenant.default_language = default_language
tenant.supported_languages = supported_languages
self.db.commit()
self.db.refresh(tenant)
business_logger.info(
f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
f"默认语言: {default_language}, 支持语言: {supported_languages}"
)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
except Exception as e:
self.db.rollback()
business_logger.error(f"更新租户语言配置失败: {str(e)}")
raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR)

View File

@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
"""普通用户修改自己的密码"""
from app.i18n.service import t
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
# 检查权限:只能修改自己的密码
if current_user.id != user_id:
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
raise PermissionDeniedException("You can only change your own password")
raise PermissionDeniedException(t("auth.password.change_failed"))
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 验证旧密码
if not verify_password(old_password, db_user.hashed_password):
business_logger.warning(f"用户旧密码验证失败: {user_id}")
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
# 更新密码
db_user.hashed_password = get_password_hash(new_password)
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
except Exception as e:
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
Returns:
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
"""
from app.i18n.service import t
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
# 检查权限:只有超级管理员可以修改他人密码
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员可以修改他人密码"
error_message=t("auth.password.change_failed")
)
except PermissionDeniedException as e:
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
if not target_user:
business_logger.warning(f"目标用户不存在: {target_user_id}")
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查租户权限:超管只能修改同租户用户的密码
if current_user.tenant_id != target_user.tenant_id:
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
# 如果没有提供新密码,则生成随机密码
actual_password = new_password if new_password else generate_random_password()
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
except Exception as e:
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
def generate_random_password(length: int = 12) -> str:
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
#
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
# return db_user
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
"""获取用户语言偏好"""
business_logger.info(f"获取用户语言偏好: user_id={user_id}")
# 权限检查:只能获取自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能获取自己的语言偏好")
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
language = db_user.preferred_language or "zh"
business_logger.info(f"用户语言偏好: {db_user.username}, language={language}")
return language
def update_user_language_preference(
db: Session,
user_id: uuid.UUID,
language: str,
current_user: User
) -> User:
"""更新用户语言偏好"""
business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")
# 权限检查:只能修改自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能修改自己的语言偏好")
# 验证语言代码是否支持
from app.core.config import settings
if language not in settings.I18N_SUPPORTED_LANGUAGES:
raise BusinessException(
f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}",
code=BizCode.VALIDATION_FAILED
)
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 更新语言偏好
db_user.preferred_language = language
db.commit()
db.refresh(db_user)
business_logger.info(f"用户语言偏好更新成功: {db_user.username}, language={language}")
return db_user