Merge branch 'develop' into release/v0.2.7

This commit is contained in:
Ke Sun
2026-03-16 15:47:00 +08:00
committed by GitHub
155 changed files with 13164 additions and 1796 deletions

View File

@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.draft_run_service import AgentRunService
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -126,8 +122,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 调用 Agent支持多模态
@@ -266,8 +271,17 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 流式调用 Agent支持多模态

View File

@@ -12,7 +12,7 @@ import uuid
from typing import Annotated, Any, Dict, List, Optional, Tuple
from fastapi import Depends
from sqlalchemy import and_, func, or_, select
from sqlalchemy import and_, delete, func, or_, select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
@@ -102,7 +102,8 @@ class AppService:
# 2. 检查是否是共享给本工作空间的应用
stmt = select(AppShare).where(
AppShare.source_app_id == app.id,
AppShare.target_workspace_id == workspace_id
AppShare.target_workspace_id == workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
@@ -125,6 +126,50 @@ class AppService:
)
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
def _get_share_permission(self, app: App, workspace_id: Optional[uuid.UUID]) -> Optional[str]:
"""获取共享应用的权限
Returns:
None: 不是共享应用(是本工作空间的应用)
'readonly': 只读共享
'editable': 可编辑共享
"""
from app.models import AppShare
if workspace_id is None or app.workspace_id == workspace_id:
return None # 本工作空间的应用,不是共享的
stmt = select(AppShare).where(
AppShare.source_app_id == app.id,
AppShare.target_workspace_id == workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
return share.permission if share else None
def _validate_app_writable(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
"""Validate that the app config is writable (owner only).
Shared apps (both readonly and editable) cannot modify config.
- Own workspace app: allowed
- Any shared app: denied
Raises:
BusinessException: when app is not writable
"""
if workspace_id is None:
return
# Own workspace app, allow
if app.workspace_id == workspace_id:
return
logger.warning(
"应用写操作被拒",
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
)
raise BusinessException("共享应用不可修改配置", BizCode.WORKSPACE_NO_ACCESS)
def _get_app_or_404(self, app_id: uuid.UUID) -> App:
"""获取应用或抛出404异常
@@ -454,6 +499,33 @@ class AppService:
Returns:
app_schema.App: 应用 Schema
"""
is_shared = app.workspace_id != current_workspace_id
share_permission = None
source_workspace_name = None
source_workspace_icon = None
source_app_version = None
source_app_is_active = None
if is_shared:
# 查询共享权限和来源工作空间名称
from app.models import AppShare
stmt = select(AppShare).where(
AppShare.source_app_id == app.id,
AppShare.target_workspace_id == current_workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
if share:
share_permission = share.permission
if share.source_workspace:
source_workspace_name = share.source_workspace.name
source_workspace_icon = share.source_workspace.icon
# 版本号和生效状态
if app.current_release:
source_app_version = app.current_release.version_name
source_app_is_active = app.is_active
app_dict = {
"id": app.id,
"workspace_id": app.workspace_id,
@@ -468,7 +540,12 @@ class AppService:
"tags": app.tags or [],
"current_release_id": app.current_release_id,
"is_active": app.is_active,
"is_shared": app.workspace_id != current_workspace_id, # 判断是否是共享应用
"is_shared": is_shared,
"share_permission": share_permission,
"source_workspace_name": source_workspace_name,
"source_workspace_icon": source_workspace_icon,
"source_app_version": source_app_version,
"source_app_is_active": source_app_is_active,
"created_at": app.created_at,
"updated_at": app.updated_at
}
@@ -594,7 +671,7 @@ class AppService:
logger.info("更新应用", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
self._validate_workspace_access(app, workspace_id)
self._validate_app_writable(app, workspace_id)
changed = False
for field in ["name", "description", "icon", "icon_type", "visibility", "status", "tags"]:
@@ -804,6 +881,7 @@ class AppService:
status: Optional[str] = None,
search: Optional[str] = None,
include_shared: bool = True,
shared_only: bool = False,
page: int = 1,
pagesize: int = 10,
) -> Tuple[List[App], int]:
@@ -849,18 +927,24 @@ class AppService:
if search:
filters.append(func.lower(App.name).like(f"%{search.lower()}%"))
# 基础查询:本工作空间的应用
if include_shared:
# 查询本工作空间的应用 + 分享给本工作空间的应用
# 使用 OR 条件workspace_id = current OR app_id IN (shared apps)
# shared_only implies include_shared; enforce to avoid confusing API usage
if shared_only:
include_shared = True
# 获取分享给本工作空间的应用ID列表
# 基础查询:本工作空间的应用
if shared_only:
# 只返回共享给本工作空间的应用,不含自有应用
shared_app_ids_stmt = (
select(AppShare.source_app_id)
.where(AppShare.target_workspace_id == workspace_id)
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
)
stmt = select(App).where(App.id.in_(shared_app_ids_stmt))
elif include_shared:
# 查询本工作空间的应用 + 分享给本工作空间的应用
shared_app_ids_stmt = (
select(AppShare.source_app_id)
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
)
# 构建主查询:本工作空间的应用 OR 分享的应用
stmt = select(App).where(
or_(
App.workspace_id == workspace_id,
@@ -952,7 +1036,7 @@ class AppService:
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
self._validate_workspace_access(app, workspace_id)
self._validate_app_writable(app, workspace_id)
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
AgentConfig.updated_at.desc())
@@ -1163,7 +1247,7 @@ class AppService:
if app.type != AppType.WORKFLOW:
raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
self._validate_workspace_access(app, workspace_id)
self._validate_app_writable(app, workspace_id)
# 获取现有配置
repo = WorkflowConfigRepository(self.db)
@@ -1654,7 +1738,8 @@ class AppService:
app_id: uuid.UUID,
target_workspace_ids: List[uuid.UUID],
user_id: uuid.UUID,
workspace_id: Optional[uuid.UUID] = None
workspace_id: Optional[uuid.UUID] = None,
permission: str = "readonly"
) -> list[AppShare]:
"""分享应用到其他工作空间
@@ -1685,6 +1770,14 @@ class AppService:
app = self._get_app_or_404(app_id)
self._validate_workspace_access(app, workspace_id)
# 仅允许 agent 和 workflow 类型共享multi_agent 不支持
from app.models.app_model import AppType
if app.type == AppType.MULTI_AGENT:
raise BusinessException(
"集群 Agent 不支持共享应用功能",
BizCode.INVALID_PARAMETER
)
# 2. 验证目标工作空间
for target_ws_id in target_workspace_ids:
target_ws = self.db.get(Workspace, target_ws_id)
@@ -1706,7 +1799,8 @@ class AppService:
# 检查是否已经分享过
stmt = select(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.target_workspace_id == target_ws_id
AppShare.target_workspace_id == target_ws_id,
AppShare.is_active.is_(True)
)
existing_share = self.db.scalars(stmt).first()
@@ -1725,6 +1819,7 @@ class AppService:
source_workspace_id=app.workspace_id,
target_workspace_id=target_ws_id,
shared_by=user_id,
permission=permission,
created_at=now,
updated_at=now
)
@@ -1784,7 +1879,8 @@ class AppService:
# 2. 查找分享记录
stmt = select(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.target_workspace_id == target_workspace_id
AppShare.target_workspace_id == target_workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
@@ -1798,8 +1894,8 @@ class AppService:
f"app_id={app_id}, target_workspace_id={target_workspace_id}"
)
# 3. 删除分享记录
self.db.delete(share)
# 3. 逻辑删除分享记录
share.is_active = False
self.db.commit()
logger.info(
@@ -1807,6 +1903,48 @@ class AppService:
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
)
def unshare_all_apps_to_workspace(
self,
*,
target_workspace_id: uuid.UUID,
workspace_id: uuid.UUID
) -> int:
"""Cancel all app shares from current workspace to a target workspace.
Args:
target_workspace_id: Target workspace ID to cancel all shares to
workspace_id: Current workspace ID (source)
Returns:
Number of share records deleted
"""
from app.models import AppShare
logger.info(
"取消对目标工作空间的所有应用分享",
extra={"target_workspace_id": str(target_workspace_id), "workspace_id": str(workspace_id)}
)
# Query active records first for reliable count
id_stmt = select(AppShare.id).where(
AppShare.source_workspace_id == workspace_id,
AppShare.target_workspace_id == target_workspace_id,
AppShare.is_active.is_(True)
)
ids = list(self.db.scalars(id_stmt).all())
count = len(ids)
if ids:
# Soft delete: mark as inactive
from sqlalchemy import update as sa_update
self.db.execute(
sa_update(AppShare).where(AppShare.id.in_(ids)).values(is_active=False)
)
self.db.commit()
logger.info("已取消分享记录数", extra={"count": count})
return count
def list_app_shares(
self,
*,
@@ -1836,7 +1974,8 @@ class AppService:
# 查询分享记录
stmt = select(AppShare).where(
AppShare.source_app_id == app_id
AppShare.source_app_id == app_id,
AppShare.is_active.is_(True)
).order_by(AppShare.created_at.desc())
shares = list(self.db.scalars(stmt).all())
@@ -1848,6 +1987,166 @@ class AppService:
return shares
def remove_shared_app(
self,
*,
app_id: uuid.UUID,
workspace_id: uuid.UUID
) -> None:
"""被共享者从自己的工作空间移除共享应用
只删除共享记录,不影响源应用。
Args:
app_id: 应用ID
workspace_id: 当前工作空间ID被共享的目标工作空间
Raises:
ResourceNotFoundException: 当共享记录不存在时
"""
from app.models import AppShare
logger.info(
"移除共享应用",
extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
)
stmt = select(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.target_workspace_id == workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
if not share:
raise ResourceNotFoundException(
"共享记录",
f"app_id={app_id}, workspace_id={workspace_id}"
)
# Soft delete
share.is_active = False
self.db.commit()
logger.info(
"共享应用已移除",
extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
)
def remove_all_shared_apps_from_workspace(
self,
*,
source_workspace_id: uuid.UUID,
workspace_id: uuid.UUID
) -> int:
"""Remove all shared apps from a specific source workspace.
Args:
source_workspace_id: The workspace that shared the apps
workspace_id: Current workspace ID (recipient)
Returns:
Number of share records deleted
"""
from app.models import AppShare
logger.info(
"批量移除来源工作空间的共享应用",
extra={"source_workspace_id": str(source_workspace_id), "workspace_id": str(workspace_id)}
)
# Query active records for reliable count, then soft delete
id_stmt = select(AppShare.id).where(
AppShare.source_workspace_id == source_workspace_id,
AppShare.target_workspace_id == workspace_id,
AppShare.is_active.is_(True)
)
ids = list(self.db.scalars(id_stmt).all())
count = len(ids)
if ids:
from sqlalchemy import update as sa_update
self.db.execute(
sa_update(AppShare).where(AppShare.id.in_(ids)).values(is_active=False)
)
self.db.commit()
logger.info("已移除共享记录数", extra={"count": count})
return count
def list_my_shared_out(
self,
*,
workspace_id: uuid.UUID
) -> List[AppShare]:
"""列出本工作空间主动分享出去的所有记录(我的共享)
Returns:
List[AppShare]: 分享记录列表,含源应用信息
"""
from app.models import AppShare
stmt = (
select(AppShare)
.where(
AppShare.source_workspace_id == workspace_id,
AppShare.is_active.is_(True)
)
.order_by(AppShare.created_at.desc())
)
return list(self.db.scalars(stmt).all())
def update_share_permission(
self,
*,
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
permission: str,
workspace_id: Optional[uuid.UUID] = None
) -> "AppShare":
"""更新共享权限readonly <-> editable
Args:
app_id: 应用ID
target_workspace_id: 目标工作空间ID
permission: 新权限值 readonly | editable
workspace_id: 当前工作空间ID用于权限验证
Returns:
AppShare: 更新后的共享记录
"""
from app.models import AppShare
if permission not in ("readonly", "editable"):
raise BusinessException("权限值无效,只允许 readonly 或 editable", BizCode.INVALID_PARAMETER)
app = self._get_app_or_404(app_id)
self._validate_workspace_access(app, workspace_id)
stmt = select(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.target_workspace_id == target_workspace_id,
AppShare.is_active.is_(True)
)
share = self.db.scalars(stmt).first()
if not share:
raise ResourceNotFoundException(
"共享记录",
f"app_id={app_id}, target_workspace_id={target_workspace_id}"
)
share.permission = permission
share.updated_at = datetime.datetime.now()
self.db.commit()
self.db.refresh(share)
logger.info(
"共享权限已更新",
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id), "permission": permission}
)
return share
# ==================== 向后兼容的函数接口 ====================
# 保留函数接口以兼容现有代码,但内部使用服务类
@@ -1942,6 +2241,7 @@ def list_apps(
status: Optional[str] = None,
search: Optional[str] = None,
include_shared: bool = True,
shared_only: bool = False,
page: int = 1,
pagesize: int = 10,
) -> Tuple[List[App], int]:
@@ -1954,6 +2254,7 @@ def list_apps(
status=status,
search=search,
include_shared=include_shared,
shared_only=shared_only,
page=page,
pagesize=pagesize,
)

View File

@@ -75,7 +75,7 @@ class AudioTranscriptionService:
try:
# 下载音频文件
async with httpx.AsyncClient(timeout=60.0) as client:
audio_response = await client.get(audio_url)
audio_response = await client.get(audio_url, follow_redirects=True)
audio_response.raise_for_status()
audio_data = audio_response.content

View File

@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_auth_logger
from app.i18n.service import t
logger = get_auth_logger()
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
user = user_repository.get_user_by_email(db, email=email)
if not user:
logger.warning(f"用户不存在: {email}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查用户状态
if not user.is_active:
logger.warning(f"用户未激活: {email}")
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
# 验证密码
if not verify_password(password, user.hashed_password):
logger.warning(f"密码错误: {email}")
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
logger.info(f"用户认证成功: {email}")
return user
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
Raises:
BusinessException: token 无效
"""
from app.i18n.service import t
try:
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
return {
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
"share_token": payload["share_token"]
}
except jwt.InvalidTokenError:
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)

View File

@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.models import AgentConfig, ModelConfig, ModelType
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
@@ -501,9 +502,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -688,7 +698,8 @@ class AgentRunService:
conversation_id=conversation_id,
app_id=agent_config.app_id,
workspace_id=workspace_id,
user_id=user_id
user_id=user_id,
sub_agent=sub_agent
)
# 6. 加载历史消息
@@ -703,9 +714,18 @@ class AgentRunService:
processed_files = None
if files:
# 获取 provider 信息
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
# 7. 知识库检索
@@ -840,7 +860,8 @@ class AgentRunService:
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
"is_omni": api_key.is_omni,
"capability": api_key.capability
}
async def _ensure_conversation(
@@ -848,7 +869,8 @@ class AgentRunService:
conversation_id: Optional[str],
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str]
user_id: Optional[str],
sub_agent: bool = False
) -> str:
"""确保会话存在(创建或验证)
@@ -909,20 +931,36 @@ class AgentRunService:
conv_uuid = uuid.UUID(conversation_id)
conversation = conversation_service.get_conversation(conv_uuid)
# 验证会话属于当前工作空间
if conversation.workspace_id != workspace_id:
logger.warning(
"会话不属于当前工作空间",
extra={
"conversation_id": conversation_id,
"conversation_workspace_id": str(conversation.workspace_id),
"current_workspace_id": str(workspace_id)
}
)
raise BusinessException(
"会话不属于当前工作空间",
BizCode.PERMISSION_DENIED
)
# 验证会话属于当前工作空间(或属于共享应用的源工作空间)
# sub_agent 内部调用时跳过校验,已在上层验证过
if not sub_agent and conversation.workspace_id != workspace_id:
# 检查是否是共享应用的会话(被共享者 workspace 访问源应用)
from app.models import AppShare
from sqlalchemy import select as sa_select
share = self.db.scalars(
sa_select(AppShare).where(
AppShare.source_app_id == app_id,
AppShare.target_workspace_id == workspace_id
)
).first()
# 情况2sub_agent 内部调用时workspace_id 是源应用的 workspace
# 而会话是被共享者创建的,只要会话属于同一个 app 即可放行
same_app = (conversation.app_id == app_id)
if not share and not same_app:
logger.warning(
"会话不属于当前工作空间",
extra={
"conversation_id": conversation_id,
"conversation_workspace_id": str(conversation.workspace_id),
"current_workspace_id": str(workspace_id)
}
)
raise BusinessException(
"会话不属于当前工作空间",
BizCode.PERMISSION_DENIED
)
logger.debug(
"使用现有会话",

View File

@@ -274,7 +274,7 @@ class MemoryAgentService:
Args:
end_user_id: Group identifier (also used as end_user_id)
message: Message to write
messages: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)

View File

@@ -1,19 +1,27 @@
import os
import uuid
from typing import Dict, Any, Optional
from urllib.parse import urlparse, unquote
import json_repair
from jinja2 import Template
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
from app.models.prompt_optimizer_model import RoleType
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
from app.schemas import FileType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualTimelineResponse,
PerceptualMemoryItem,
AudioModal, Content, VideoModal, TextModal
)
from app.schemas.model_schema import ModelInfo
business_logger = get_business_logger()
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
"keywords": content.keywords,
"topic": content.topic,
"domain": content.domain,
"created_time": int(memory.created_time.timestamp()*1000),
"created_time": int(memory.created_time.timestamp() * 1000),
**detail
}
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
return result
except Exception as e:
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
exc_info=True)
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
BizCode.DB_ERROR)
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
for memory in memories:
meta_data = memory.meta_data or {}
content = meta_data.get("content", {})
# 安全地提取 content 字段,提供默认值
if content:
content_obj = Content(**content)
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
topic = "Unknown"
domain = "Unknown"
keywords = []
memory_item = PerceptualMemoryItem(
id=memory.id,
perceptual_type=PerceptualType(memory.perceptual_type),
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
topic=topic,
domain=domain,
keywords=keywords,
created_time=int(memory.created_time.timestamp()*1000),
created_time=int(memory.created_time.timestamp() * 1000),
storage_service=FileStorageService(memory.storage_service),
)
memory_items.append(memory_item)
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
except Exception as e:
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
async def generate_perceptual_memory(
self,
end_user_id: str,
model_config: ModelInfo,
file_type: str,
file_url: str,
file_message: dict,
):
memories = self.repository.get_by_url(file_url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file_url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return
llm = RedBearLLM(RedBearModelConfig(
model_name=model_config.model_name,
provider=model_config.provider,
api_key=model_config.api_key,
base_url=model_config.api_base,
is_omni=model_config.is_omni
), type=model_config.model_type)
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message
]}
]
result = await llm.ainvoke(messages)
content = json_repair.repair_json(result.content, return_objects=True)
path = urlparse(file_url).path
filename = os.path.basename(path)
filename = unquote(filename)
file_ext = os.path.splitext(filename)[1]
if not file_ext:
if file_type == FileType.AUDIO:
file_ext = ".mp3"
elif file_type == FileType.VIDEO:
file_ext = ".mp4"
elif file_type == FileType.DOCUMENT:
file_ext = ".txt"
elif file_type == FileType.IMAGE:
file_ext = ".jpg"
filename += file_ext
file_content = {
"keywords": content.get("keywords", []),
"topic": content.get("topic"),
"domain": content.get("domain")
}
if file_type in [FileType.IMAGE, FileType.VIDEO]:
file_modalities = {
"scene": content.get("scene")
}
elif file_type in [FileType.DOCUMENT]:
file_modalities = {
"section_count": content.get("section_count"),
"title": content.get("title"),
"first_line": content.get("first_line")
}
else:
file_modalities = {
"speaker_count": content.get("speaker_count")
}
self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType.trans_from_file_type(file_type),
file_path=file_url,
file_name=filename,
file_ext=file_ext,
summary=content.get('summary'),
meta_data={
"content": file_content,
"modalities": file_modalities
}
)
self.db.commit()

View File

@@ -10,6 +10,7 @@
"""
import base64
import io
import uuid
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -23,9 +24,12 @@ from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.models import ModelApiKey
from app.models.file_metadata_model import FileMetadata
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.schemas.model_schema import ModelInfo
from app.services.audio_transcription_service import AudioTranscriptionService
from app.tasks import write_perceptual_memory
logger = get_business_logger()
@@ -39,6 +43,7 @@ DOC_MIME = [
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
def __init__(self, file: FileInput):
self.file = file
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
# 下载图片
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
content = response.content
self.file.set_content(content)
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
audio_data = content
if content is None:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response = await client.get(url, follow_redirects=True)
response.raise_for_status()
audio_data = response.content
self.file.set_content(audio_data)
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
"""
Service for handling multimodal file processing.
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
enable_audio_transcription: bool = False, is_omni: bool = False):
Attributes:
db (Session): Database session.
model_api_key (str): API key for the model provider.
provider (str): Name of the model provider.
is_omni (bool): Indicates whether the model supports full multimodal capability.
capability (list): Capability configuration of the model.
audio_api_key (str | None): API key used for audio transcription.
enable_audio_transcription (bool): Whether audio transcription is enabled.
"""
def __init__(
self,
db: Session,
api_config: ModelInfo | None = None,
audio_api_key: Optional[str] = None,
enable_audio_transcription: bool = False,
):
"""
初始化多模态服务
Initialize the multimodal service.
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
db (Session): Database session.
api_config (ModelApiKey | None): Model API configuration.
audio_api_key (str | None): API key for audio transcription.
enable_audio_transcription (bool): Enable audio transcription.
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.api_config = api_config
if self.api_config is not None:
self.model_api_key = api_config.api_key
self.provider = api_config.provider.lower()
self.is_omni = api_config.is_omni
self.capability = api_config.capability
self.audio_api_key = audio_api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
files: Optional[List[FileInput]]
end_user_id: uuid.UUID | str,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
Args:
end_user_id: 用户ID
files: 文件输入列表
Returns:
@@ -319,6 +346,8 @@ class MultimodalService:
"""
if not files:
return []
if isinstance(end_user_id, uuid.UUID):
end_user_id = str(end_user_id)
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
@@ -333,19 +362,25 @@ class MultimodalService:
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try:
if file.type == FileType.IMAGE:
if file.type == FileType.IMAGE and "vision" in self.capability:
content = await self._process_image(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
elif file.type == FileType.VIDEO and "video" in self.capability:
content = await self._process_video(file, strategy)
result.append(content)
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
except Exception as e:
@@ -355,7 +390,8 @@ class MultimodalService:
"file_index": idx,
"file_type": file.type,
"error": str(e)
}
},
exc_info=True
)
# 继续处理其他文件,不中断整个流程
result.append({
@@ -366,6 +402,17 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
def write_perceptual_memory(
self,
end_user_id: str,
file_type: str,
file_url: str,
file_message: dict
):
"""写入感知记忆"""
if end_user_id and self.api_config:
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
@@ -387,43 +434,6 @@ class MultimodalService:
"text": f"[图片处理失败: {str(e)}]"
}
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
Args:
url: 图片 URL
Returns:
tuple: (base64_data, media_type)
"""
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
media_type = content_type
else:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
@@ -436,7 +446,6 @@ class MultimodalService:
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
return {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
@@ -471,12 +480,12 @@ class MultimodalService:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
if self.enable_audio_transcription and self.audio_api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
@@ -557,7 +566,7 @@ class MultimodalService:
file_content = file.get_content()
if not file_content:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file.url)
response = await client.get(file.url, follow_redirects=True)
response.raise_for_status()
file_content = response.content
file.set_content(file_content)

View File

@@ -0,0 +1,53 @@
{% raw %}You are a professional information extraction system.
Your task is to analyze the provided document content and generate structured metadata.
Extract the following fields:
* **summary**: A concise summary of the document in 24 sentences.
* **keywords**: 510 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
* **topic**: The primary topic of the document expressed as a short phrase (38 words).
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
STRICT RULES:
1. Output MUST be valid JSON.
2. Do NOT output markdown.
3. Do NOT output explanations.
4. Do NOT output any text before or after the JSON.
5. The JSON MUST contain EXACTLY these four keys:
* summary
* keywords
* topic
* domain{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
{% if file_type == 'audio' %} * speaker_count {% endif %}
{% if file_type == 'document' %} * section_count
* title
* first_line
{% endif %}
{% raw %}
6. `keywords` MUST be a JSON array of strings.
7. If the document content is insufficient, infer the best possible answer based on context.
8. Ensure the JSON is syntactically correct.
{% endraw %}
9. Output using the language {{ language }}
{% raw %}
Required JSON format:
{
"summary": "string",
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
"topic": "string",
"domain": "string",
{% endraw %}
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
{% if file_type == 'document' %} "section_count": integer
"title": "string",
"first_line": "string"
{% endif %}
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
{% raw %}
}
Now analyze the following document and return the JSON result.{% endraw %}

View File

@@ -217,4 +217,55 @@ class TenantService:
skip=skip,
limit=limit,
is_active=is_active
)
)
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
"""获取租户语言配置"""
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
def update_tenant_language_config(
self,
tenant_id: uuid.UUID,
default_language: str,
supported_languages: list
) -> Optional[dict]:
"""更新租户语言配置"""
# 检查租户是否存在
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
# 验证默认语言在支持的语言列表中
if default_language not in supported_languages:
raise BusinessException(
"默认语言必须在支持的语言列表中",
code=BizCode.VALIDATION_FAILED
)
try:
# 更新语言配置
tenant.default_language = default_language
tenant.supported_languages = supported_languages
self.db.commit()
self.db.refresh(tenant)
business_logger.info(
f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
f"默认语言: {default_language}, 支持语言: {supported_languages}"
)
return {
"default_language": tenant.default_language,
"supported_languages": tenant.supported_languages
}
except Exception as e:
self.db.rollback()
business_logger.error(f"更新租户语言配置失败: {str(e)}")
raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR)

View File

@@ -1727,6 +1727,150 @@ async def analytics_graph_data(
# 辅助函数
async def analytics_community_graph_data(
db: Session,
end_user_id: str,
) -> Dict[str, Any]:
"""
获取社区图谱数据,包含 Community 节点、ExtractedEntity 节点及其关系。
Returns:
包含 nodes、edges、statistics 的字典,格式与 analytics_graph_data 一致
"""
try:
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
return {
"nodes": [], "edges": [],
"statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
"message": "用户不存在"
}
# 查询社区节点、实体节点、BELONGS_TO_COMMUNITY 边、实体间关系
from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA
rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id)
nodes_map: Dict[str, dict] = {}
edges_map: Dict[str, dict] = {}
# 记录每个 Community 对应的实体 id 列表
community_members: Dict[str, list] = {}
for row in rows:
# Community 节点
c_id = row["c_id"]
if c_id and c_id not in nodes_map:
raw = row["c_props"] or {}
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
"community_id", "end_user_id", "member_count", "updated_at",
"name", "summary", "core_entities",
) if k in raw}
nodes_map[c_id] = {
"id": c_id,
"label": "Community",
"properties": props,
}
# ExtractedEntity 节点 (e)
e_id = row["e_id"]
if e_id and e_id not in nodes_map:
raw = row["e_props"] or {}
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
"name", "end_user_id", "description", "created_at", "entity_type",
) if k in raw}
# 注入所属社区名称c 是 e 直接归属的社区)
c_raw = row["c_props"] or {}
props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or ""
nodes_map[e_id] = {
"id": e_id,
"label": "ExtractedEntity",
"properties": props,
}
# ExtractedEntity 节点 (e2可选)
e2_id = row.get("e2_id")
if e2_id and e2_id not in nodes_map:
raw = row["e2_props"] or {}
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
"name", "end_user_id", "description", "created_at", "entity_type",
) if k in raw}
# e2 的社区归属在后处理阶段通过 community_members 补充
props["community_name"] = ""
nodes_map[e2_id] = {
"id": e2_id,
"label": "ExtractedEntity",
"properties": props,
}
# BELONGS_TO_COMMUNITY 边
b_id = row["b_id"]
if b_id and b_id not in edges_map:
edges_map[b_id] = {
"id": b_id,
"source": e_id,
"target": c_id,
}
# 收集社区成员 id
if c_id and e_id:
community_members.setdefault(c_id, [])
if e_id not in community_members[c_id]:
community_members[c_id].append(e_id)
# EXTRACTED_RELATIONSHIP 边(可选)
r_id = row.get("r_id")
if r_id and r_id not in edges_map and e2_id:
r_props = {k: _clean_neo4j_value(v) for k, v in (row["r_props"] or {}).items()}
source = e_id if row.get("r_from_e") else e2_id
target = e2_id if row.get("r_from_e") else e_id
edges_map[r_id] = {
"id": r_id,
"source": source,
"target": target,
}
nodes = list(nodes_map.values())
edges = list(edges_map.values())
# 为每个 Community 节点注入 member_entity_ids同时补全 e2 节点的 community_name
for c_id, member_ids in community_members.items():
c_node = nodes_map.get(c_id)
if c_node:
c_node["properties"]["member_entity_ids"] = member_ids
c_name = c_node["properties"].get("name") or ""
# 补全属于该社区但 community_name 为空的实体(即 e2 节点)
for eid in member_ids:
e_node = nodes_map.get(eid)
if e_node and e_node["label"] == "ExtractedEntity":
if not e_node["properties"].get("community_name"):
e_node["properties"]["community_name"] = c_name
node_type_counts: Dict[str, int] = {}
for n in nodes:
node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1
return {
"nodes": nodes,
"edges": edges,
"statistics": {
"total_nodes": len(nodes),
"total_edges": len(edges),
"node_types": node_type_counts,
}
}
except ValueError:
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
return {
"nodes": [], "edges": [],
"statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
"message": "无效的用户ID格式"
}
except Exception as e:
logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True)
raise
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
"""
根据节点类型提取需要的属性字段

View File

@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
"""普通用户修改自己的密码"""
from app.i18n.service import t
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
# 检查权限:只能修改自己的密码
if current_user.id != user_id:
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
raise PermissionDeniedException("You can only change your own password")
raise PermissionDeniedException(t("auth.password.change_failed"))
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 验证旧密码
if not verify_password(old_password, db_user.hashed_password):
business_logger.warning(f"用户旧密码验证失败: {user_id}")
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
# 更新密码
db_user.hashed_password = get_password_hash(new_password)
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
except Exception as e:
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
Returns:
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
"""
from app.i18n.service import t
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
# 检查权限:只有超级管理员可以修改他人密码
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员可以修改他人密码"
error_message=t("auth.password.change_failed")
)
except PermissionDeniedException as e:
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
if not target_user:
business_logger.warning(f"目标用户不存在: {target_user_id}")
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
# 检查租户权限:超管只能修改同租户用户的密码
if current_user.tenant_id != target_user.tenant_id:
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
# 如果没有提供新密码,则生成随机密码
actual_password = new_password if new_password else generate_random_password()
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
except Exception as e:
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
def generate_random_password(length: int = 12) -> str:
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
#
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
# return db_user
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
"""获取用户语言偏好"""
business_logger.info(f"获取用户语言偏好: user_id={user_id}")
# 权限检查:只能获取自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能获取自己的语言偏好")
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
language = db_user.preferred_language or "zh"
business_logger.info(f"用户语言偏好: {db_user.username}, language={language}")
return language
def update_user_language_preference(
db: Session,
user_id: uuid.UUID,
language: str,
current_user: User
) -> User:
"""更新用户语言偏好"""
business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")
# 权限检查:只能修改自己的语言偏好
if current_user.id != user_id:
raise PermissionDeniedException("只能修改自己的语言偏好")
# 验证语言代码是否支持
from app.core.config import settings
if language not in settings.I18N_SUPPORTED_LANGUAGES:
raise BusinessException(
f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}",
code=BizCode.VALIDATION_FAILED
)
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 更新语言偏好
db_user.preferred_language = language
db.commit()
db.refresh(db_user)
business_logger.info(f"用户语言偏好更新成功: {db_user.username}, language={language}")
return db_user