Merge branch 'develop' into release/v0.2.7
This commit is contained in:
@@ -8,25 +8,21 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||
AgentRunService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -126,8 +122,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 调用 Agent(支持多模态)
|
||||
@@ -266,8 +271,17 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
|
||||
@@ -12,7 +12,7 @@ import uuid
|
||||
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy import and_, delete, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -102,7 +102,8 @@ class AppService:
|
||||
# 2. 检查是否是共享给本工作空间的应用
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app.id,
|
||||
AppShare.target_workspace_id == workspace_id
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -125,6 +126,50 @@ class AppService:
|
||||
)
|
||||
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
|
||||
|
||||
def _get_share_permission(self, app: App, workspace_id: Optional[uuid.UUID]) -> Optional[str]:
|
||||
"""获取共享应用的权限
|
||||
|
||||
Returns:
|
||||
None: 不是共享应用(是本工作空间的应用)
|
||||
'readonly': 只读共享
|
||||
'editable': 可编辑共享
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
if workspace_id is None or app.workspace_id == workspace_id:
|
||||
return None # 本工作空间的应用,不是共享的
|
||||
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app.id,
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
return share.permission if share else None
|
||||
|
||||
def _validate_app_writable(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
|
||||
"""Validate that the app config is writable (owner only).
|
||||
|
||||
Shared apps (both readonly and editable) cannot modify config.
|
||||
- Own workspace app: allowed
|
||||
- Any shared app: denied
|
||||
|
||||
Raises:
|
||||
BusinessException: when app is not writable
|
||||
"""
|
||||
if workspace_id is None:
|
||||
return
|
||||
|
||||
# Own workspace app, allow
|
||||
if app.workspace_id == workspace_id:
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"应用写操作被拒",
|
||||
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
raise BusinessException("共享应用不可修改配置", BizCode.WORKSPACE_NO_ACCESS)
|
||||
|
||||
def _get_app_or_404(self, app_id: uuid.UUID) -> App:
|
||||
"""获取应用或抛出404异常
|
||||
|
||||
@@ -454,6 +499,33 @@ class AppService:
|
||||
Returns:
|
||||
app_schema.App: 应用 Schema
|
||||
"""
|
||||
is_shared = app.workspace_id != current_workspace_id
|
||||
share_permission = None
|
||||
source_workspace_name = None
|
||||
source_workspace_icon = None
|
||||
source_app_version = None
|
||||
source_app_is_active = None
|
||||
|
||||
if is_shared:
|
||||
# 查询共享权限和来源工作空间名称
|
||||
from app.models import AppShare
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app.id,
|
||||
AppShare.target_workspace_id == current_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
if share:
|
||||
share_permission = share.permission
|
||||
if share.source_workspace:
|
||||
source_workspace_name = share.source_workspace.name
|
||||
source_workspace_icon = share.source_workspace.icon
|
||||
|
||||
# 版本号和生效状态
|
||||
if app.current_release:
|
||||
source_app_version = app.current_release.version_name
|
||||
source_app_is_active = app.is_active
|
||||
|
||||
app_dict = {
|
||||
"id": app.id,
|
||||
"workspace_id": app.workspace_id,
|
||||
@@ -468,7 +540,12 @@ class AppService:
|
||||
"tags": app.tags or [],
|
||||
"current_release_id": app.current_release_id,
|
||||
"is_active": app.is_active,
|
||||
"is_shared": app.workspace_id != current_workspace_id, # 判断是否是共享应用
|
||||
"is_shared": is_shared,
|
||||
"share_permission": share_permission,
|
||||
"source_workspace_name": source_workspace_name,
|
||||
"source_workspace_icon": source_workspace_icon,
|
||||
"source_app_version": source_app_version,
|
||||
"source_app_is_active": source_app_is_active,
|
||||
"created_at": app.created_at,
|
||||
"updated_at": app.updated_at
|
||||
}
|
||||
@@ -594,7 +671,7 @@ class AppService:
|
||||
logger.info("更新应用", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
changed = False
|
||||
for field in ["name", "description", "icon", "icon_type", "visibility", "status", "tags"]:
|
||||
@@ -804,6 +881,7 @@ class AppService:
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
) -> Tuple[List[App], int]:
|
||||
@@ -849,18 +927,24 @@ class AppService:
|
||||
if search:
|
||||
filters.append(func.lower(App.name).like(f"%{search.lower()}%"))
|
||||
|
||||
# 基础查询:本工作空间的应用
|
||||
if include_shared:
|
||||
# 查询本工作空间的应用 + 分享给本工作空间的应用
|
||||
# 使用 OR 条件:workspace_id = current OR app_id IN (shared apps)
|
||||
# shared_only implies include_shared; enforce to avoid confusing API usage
|
||||
if shared_only:
|
||||
include_shared = True
|
||||
|
||||
# 获取分享给本工作空间的应用ID列表
|
||||
# 基础查询:本工作空间的应用
|
||||
if shared_only:
|
||||
# 只返回共享给本工作空间的应用,不含自有应用
|
||||
shared_app_ids_stmt = (
|
||||
select(AppShare.source_app_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
|
||||
)
|
||||
stmt = select(App).where(App.id.in_(shared_app_ids_stmt))
|
||||
elif include_shared:
|
||||
# 查询本工作空间的应用 + 分享给本工作空间的应用
|
||||
shared_app_ids_stmt = (
|
||||
select(AppShare.source_app_id)
|
||||
.where(AppShare.target_workspace_id == workspace_id, AppShare.is_active.is_(True))
|
||||
)
|
||||
|
||||
# 构建主查询:本工作空间的应用 OR 分享的应用
|
||||
stmt = select(App).where(
|
||||
or_(
|
||||
App.workspace_id == workspace_id,
|
||||
@@ -952,7 +1036,7 @@ class AppService:
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
|
||||
AgentConfig.updated_at.desc())
|
||||
@@ -1163,7 +1247,7 @@ class AppService:
|
||||
if app.type != AppType.WORKFLOW:
|
||||
raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
self._validate_app_writable(app, workspace_id)
|
||||
|
||||
# 获取现有配置
|
||||
repo = WorkflowConfigRepository(self.db)
|
||||
@@ -1654,7 +1738,8 @@ class AppService:
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_ids: List[uuid.UUID],
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
permission: str = "readonly"
|
||||
) -> list[AppShare]:
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
@@ -1685,6 +1770,14 @@ class AppService:
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
|
||||
# 仅允许 agent 和 workflow 类型共享,multi_agent 不支持
|
||||
from app.models.app_model import AppType
|
||||
if app.type == AppType.MULTI_AGENT:
|
||||
raise BusinessException(
|
||||
"集群 Agent 不支持共享应用功能",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 2. 验证目标工作空间
|
||||
for target_ws_id in target_workspace_ids:
|
||||
target_ws = self.db.get(Workspace, target_ws_id)
|
||||
@@ -1706,7 +1799,8 @@ class AppService:
|
||||
# 检查是否已经分享过
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == target_ws_id
|
||||
AppShare.target_workspace_id == target_ws_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
existing_share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -1725,6 +1819,7 @@ class AppService:
|
||||
source_workspace_id=app.workspace_id,
|
||||
target_workspace_id=target_ws_id,
|
||||
shared_by=user_id,
|
||||
permission=permission,
|
||||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
@@ -1784,7 +1879,8 @@ class AppService:
|
||||
# 2. 查找分享记录
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == target_workspace_id
|
||||
AppShare.target_workspace_id == target_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
@@ -1798,8 +1894,8 @@ class AppService:
|
||||
f"app_id={app_id}, target_workspace_id={target_workspace_id}"
|
||||
)
|
||||
|
||||
# 3. 删除分享记录
|
||||
self.db.delete(share)
|
||||
# 3. 逻辑删除分享记录
|
||||
share.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
@@ -1807,6 +1903,48 @@ class AppService:
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
|
||||
)
|
||||
|
||||
def unshare_all_apps_to_workspace(
|
||||
self,
|
||||
*,
|
||||
target_workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> int:
|
||||
"""Cancel all app shares from current workspace to a target workspace.
|
||||
|
||||
Args:
|
||||
target_workspace_id: Target workspace ID to cancel all shares to
|
||||
workspace_id: Current workspace ID (source)
|
||||
|
||||
Returns:
|
||||
Number of share records deleted
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
logger.info(
|
||||
"取消对目标工作空间的所有应用分享",
|
||||
extra={"target_workspace_id": str(target_workspace_id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
# Query active records first for reliable count
|
||||
id_stmt = select(AppShare.id).where(
|
||||
AppShare.source_workspace_id == workspace_id,
|
||||
AppShare.target_workspace_id == target_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
ids = list(self.db.scalars(id_stmt).all())
|
||||
count = len(ids)
|
||||
|
||||
if ids:
|
||||
# Soft delete: mark as inactive
|
||||
from sqlalchemy import update as sa_update
|
||||
self.db.execute(
|
||||
sa_update(AppShare).where(AppShare.id.in_(ids)).values(is_active=False)
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
logger.info("已取消分享记录数", extra={"count": count})
|
||||
return count
|
||||
|
||||
def list_app_shares(
|
||||
self,
|
||||
*,
|
||||
@@ -1836,7 +1974,8 @@ class AppService:
|
||||
|
||||
# 查询分享记录
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.is_active.is_(True)
|
||||
).order_by(AppShare.created_at.desc())
|
||||
|
||||
shares = list(self.db.scalars(stmt).all())
|
||||
@@ -1848,6 +1987,166 @@ class AppService:
|
||||
|
||||
return shares
|
||||
|
||||
def remove_shared_app(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> None:
|
||||
"""被共享者从自己的工作空间移除共享应用
|
||||
|
||||
只删除共享记录,不影响源应用。
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
workspace_id: 当前工作空间ID(被共享的目标工作空间)
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当共享记录不存在时
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
logger.info(
|
||||
"移除共享应用",
|
||||
extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
if not share:
|
||||
raise ResourceNotFoundException(
|
||||
"共享记录",
|
||||
f"app_id={app_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
# Soft delete
|
||||
share.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
"共享应用已移除",
|
||||
extra={"app_id": str(app_id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
def remove_all_shared_apps_from_workspace(
|
||||
self,
|
||||
*,
|
||||
source_workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> int:
|
||||
"""Remove all shared apps from a specific source workspace.
|
||||
|
||||
Args:
|
||||
source_workspace_id: The workspace that shared the apps
|
||||
workspace_id: Current workspace ID (recipient)
|
||||
|
||||
Returns:
|
||||
Number of share records deleted
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
logger.info(
|
||||
"批量移除来源工作空间的共享应用",
|
||||
extra={"source_workspace_id": str(source_workspace_id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
# Query active records for reliable count, then soft delete
|
||||
id_stmt = select(AppShare.id).where(
|
||||
AppShare.source_workspace_id == source_workspace_id,
|
||||
AppShare.target_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
ids = list(self.db.scalars(id_stmt).all())
|
||||
count = len(ids)
|
||||
|
||||
if ids:
|
||||
from sqlalchemy import update as sa_update
|
||||
self.db.execute(
|
||||
sa_update(AppShare).where(AppShare.id.in_(ids)).values(is_active=False)
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
logger.info("已移除共享记录数", extra={"count": count})
|
||||
return count
|
||||
|
||||
def list_my_shared_out(
|
||||
self,
|
||||
*,
|
||||
workspace_id: uuid.UUID
|
||||
) -> List[AppShare]:
|
||||
"""列出本工作空间主动分享出去的所有记录(我的共享)
|
||||
|
||||
Returns:
|
||||
List[AppShare]: 分享记录列表,含源应用信息
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
stmt = (
|
||||
select(AppShare)
|
||||
.where(
|
||||
AppShare.source_workspace_id == workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
.order_by(AppShare.created_at.desc())
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
def update_share_permission(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
permission: str,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> "AppShare":
|
||||
"""更新共享权限(readonly <-> editable)
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
target_workspace_id: 目标工作空间ID
|
||||
permission: 新权限值 readonly | editable
|
||||
workspace_id: 当前工作空间ID(用于权限验证)
|
||||
|
||||
Returns:
|
||||
AppShare: 更新后的共享记录
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
if permission not in ("readonly", "editable"):
|
||||
raise BusinessException("权限值无效,只允许 readonly 或 editable", BizCode.INVALID_PARAMETER)
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
|
||||
stmt = select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == target_workspace_id,
|
||||
AppShare.is_active.is_(True)
|
||||
)
|
||||
share = self.db.scalars(stmt).first()
|
||||
|
||||
if not share:
|
||||
raise ResourceNotFoundException(
|
||||
"共享记录",
|
||||
f"app_id={app_id}, target_workspace_id={target_workspace_id}"
|
||||
)
|
||||
|
||||
share.permission = permission
|
||||
share.updated_at = datetime.datetime.now()
|
||||
self.db.commit()
|
||||
self.db.refresh(share)
|
||||
|
||||
logger.info(
|
||||
"共享权限已更新",
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id), "permission": permission}
|
||||
)
|
||||
return share
|
||||
|
||||
|
||||
# ==================== 向后兼容的函数接口 ====================
|
||||
# 保留函数接口以兼容现有代码,但内部使用服务类
|
||||
|
||||
@@ -1942,6 +2241,7 @@ def list_apps(
|
||||
status: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
) -> Tuple[List[App], int]:
|
||||
@@ -1954,6 +2254,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
@@ -75,7 +75,7 @@ class AudioTranscriptionService:
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response = await client.get(audio_url, follow_redirects=True)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_auth_logger
|
||||
from app.i18n.service import t
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
@@ -87,17 +88,17 @@ def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
logger.warning(f"用户不存在: {email}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
logger.warning(f"用户未激活: {email}")
|
||||
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.login.account_disabled"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
logger.warning(f"密码错误: {email}")
|
||||
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.PASSWORD_ERROR)
|
||||
|
||||
logger.info(f"用户认证成功: {email}")
|
||||
return user
|
||||
@@ -254,6 +255,8 @@ def decode_access_token(token: str) -> dict:
|
||||
Raises:
|
||||
BusinessException: token 无效
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
|
||||
return {
|
||||
@@ -261,4 +264,4 @@ def decode_access_token(token: str) -> dict:
|
||||
"share_token": payload["share_token"]
|
||||
}
|
||||
except jwt.InvalidTokenError:
|
||||
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
|
||||
raise BusinessException(t("auth.token.invalid"), BizCode.INVALID_TOKEN)
|
||||
@@ -23,9 +23,10 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -501,9 +502,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -688,7 +698,8 @@ class AgentRunService:
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
sub_agent=sub_agent
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
@@ -703,9 +714,18 @@ class AgentRunService:
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(user_id, files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
# 7. 知识库检索
|
||||
@@ -840,7 +860,8 @@ class AgentRunService:
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
"is_omni": api_key.is_omni,
|
||||
"capability": api_key.capability
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
@@ -848,7 +869,8 @@ class AgentRunService:
|
||||
conversation_id: Optional[str],
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str]
|
||||
user_id: Optional[str],
|
||||
sub_agent: bool = False
|
||||
) -> str:
|
||||
"""确保会话存在(创建或验证)
|
||||
|
||||
@@ -909,20 +931,36 @@ class AgentRunService:
|
||||
conv_uuid = uuid.UUID(conversation_id)
|
||||
conversation = conversation_service.get_conversation(conv_uuid)
|
||||
|
||||
# 验证会话属于当前工作空间
|
||||
if conversation.workspace_id != workspace_id:
|
||||
logger.warning(
|
||||
"会话不属于当前工作空间",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_workspace_id": str(conversation.workspace_id),
|
||||
"current_workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
raise BusinessException(
|
||||
"会话不属于当前工作空间",
|
||||
BizCode.PERMISSION_DENIED
|
||||
)
|
||||
# 验证会话属于当前工作空间(或属于共享应用的源工作空间)
|
||||
# sub_agent 内部调用时跳过校验,已在上层验证过
|
||||
if not sub_agent and conversation.workspace_id != workspace_id:
|
||||
# 检查是否是共享应用的会话(被共享者 workspace 访问源应用)
|
||||
from app.models import AppShare
|
||||
from sqlalchemy import select as sa_select
|
||||
share = self.db.scalars(
|
||||
sa_select(AppShare).where(
|
||||
AppShare.source_app_id == app_id,
|
||||
AppShare.target_workspace_id == workspace_id
|
||||
)
|
||||
).first()
|
||||
|
||||
# 情况2:sub_agent 内部调用时,workspace_id 是源应用的 workspace,
|
||||
# 而会话是被共享者创建的,只要会话属于同一个 app 即可放行
|
||||
same_app = (conversation.app_id == app_id)
|
||||
|
||||
if not share and not same_app:
|
||||
logger.warning(
|
||||
"会话不属于当前工作空间",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"conversation_workspace_id": str(conversation.workspace_id),
|
||||
"current_workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
raise BusinessException(
|
||||
"会话不属于当前工作空间",
|
||||
BizCode.PERMISSION_DENIED
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"使用现有会话",
|
||||
|
||||
@@ -274,7 +274,7 @@ class MemoryAgentService:
|
||||
|
||||
Args:
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
|
||||
@@ -1,19 +1,27 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas import FileType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualTimelineResponse,
|
||||
PerceptualMemoryItem,
|
||||
AudioModal, Content, VideoModal, TextModal
|
||||
)
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
@@ -99,7 +107,7 @@ class MemoryPerceptualService:
|
||||
"keywords": content.keywords,
|
||||
"topic": content.topic,
|
||||
"domain": content.domain,
|
||||
"created_time": int(memory.created_time.timestamp()*1000),
|
||||
"created_time": int(memory.created_time.timestamp() * 1000),
|
||||
**detail
|
||||
}
|
||||
|
||||
@@ -108,7 +116,8 @@ class MemoryPerceptualService:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}")
|
||||
business_logger.error(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
exc_info=True)
|
||||
raise BusinessException(f"Failed to fetch latest {perceptual_type.name.lower()} memory: {str(e)}",
|
||||
BizCode.DB_ERROR)
|
||||
|
||||
@@ -138,7 +147,7 @@ class MemoryPerceptualService:
|
||||
for memory in memories:
|
||||
meta_data = memory.meta_data or {}
|
||||
content = meta_data.get("content", {})
|
||||
|
||||
|
||||
# 安全地提取 content 字段,提供默认值
|
||||
if content:
|
||||
content_obj = Content(**content)
|
||||
@@ -149,7 +158,7 @@ class MemoryPerceptualService:
|
||||
topic = "Unknown"
|
||||
domain = "Unknown"
|
||||
keywords = []
|
||||
|
||||
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
@@ -161,7 +170,7 @@ class MemoryPerceptualService:
|
||||
topic=topic,
|
||||
domain=domain,
|
||||
keywords=keywords,
|
||||
created_time=int(memory.created_time.timestamp()*1000),
|
||||
created_time=int(memory.created_time.timestamp() * 1000),
|
||||
storage_service=FileStorageService(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
@@ -183,3 +192,98 @@ class MemoryPerceptualService:
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to fetch perceptual memory timeline: {str(e)}")
|
||||
raise BusinessException(f"Failed to fetch perceptual memory timeline: {str(e)}", BizCode.DB_ERROR)
|
||||
|
||||
async def generate_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_config: ModelInfo,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict,
|
||||
):
|
||||
memories = self.repository.get_by_url(file_url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file_url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
return
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
api_key=model_config.api_key,
|
||||
base_url=model_config.api_base,
|
||||
is_omni=model_config.is_omni
|
||||
), type=model_config.model_type)
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file_type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
content = json_repair.repair_json(result.content, return_objects=True)
|
||||
path = urlparse(file_url).path
|
||||
filename = os.path.basename(path)
|
||||
filename = unquote(filename)
|
||||
file_ext = os.path.splitext(filename)[1]
|
||||
if not file_ext:
|
||||
if file_type == FileType.AUDIO:
|
||||
file_ext = ".mp3"
|
||||
elif file_type == FileType.VIDEO:
|
||||
file_ext = ".mp4"
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
file_ext = ".txt"
|
||||
elif file_type == FileType.IMAGE:
|
||||
file_ext = ".jpg"
|
||||
filename += file_ext
|
||||
file_content = {
|
||||
"keywords": content.get("keywords", []),
|
||||
"topic": content.get("topic"),
|
||||
"domain": content.get("domain")
|
||||
}
|
||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||
file_modalities = {
|
||||
"scene": content.get("scene")
|
||||
}
|
||||
elif file_type in [FileType.DOCUMENT]:
|
||||
file_modalities = {
|
||||
"section_count": content.get("section_count"),
|
||||
"title": content.get("title"),
|
||||
"first_line": content.get("first_line")
|
||||
}
|
||||
else:
|
||||
file_modalities = {
|
||||
"speaker_count": content.get("speaker_count")
|
||||
}
|
||||
self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType.trans_from_file_type(file_type),
|
||||
file_path=file_url,
|
||||
file_name=filename,
|
||||
file_ext=file_ext,
|
||||
summary=content.get('summary'),
|
||||
meta_data={
|
||||
"content": file_content,
|
||||
"modalities": file_modalities
|
||||
}
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"""
|
||||
import base64
|
||||
import io
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
@@ -23,9 +24,12 @@ from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import ModelApiKey
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
from app.tasks import write_perceptual_memory
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -39,6 +43,7 @@ DOC_MIME = [
|
||||
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
def __init__(self, file: FileInput):
|
||||
self.file = file
|
||||
|
||||
@@ -95,7 +100,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
@@ -125,7 +130,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
# 下载图片
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
self.file.set_content(content)
|
||||
@@ -231,7 +236,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
audio_data = content
|
||||
if content is None:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
self.file.set_content(audio_data)
|
||||
@@ -284,34 +289,56 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
"""
|
||||
Service for handling multimodal file processing.
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
Attributes:
|
||||
db (Session): Database session.
|
||||
model_api_key (str): API key for the model provider.
|
||||
provider (str): Name of the model provider.
|
||||
is_omni (bool): Indicates whether the model supports full multimodal capability.
|
||||
capability (list): Capability configuration of the model.
|
||||
audio_api_key (str | None): API key used for audio transcription.
|
||||
enable_audio_transcription (bool): Whether audio transcription is enabled.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_config: ModelInfo | None = None,
|
||||
audio_api_key: Optional[str] = None,
|
||||
enable_audio_transcription: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Initialize the multimodal service.
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
db (Session): Database session.
|
||||
api_config (ModelApiKey | None): Model API configuration.
|
||||
audio_api_key (str | None): API key for audio transcription.
|
||||
enable_audio_transcription (bool): Enable audio transcription.
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.api_config = api_config
|
||||
if self.api_config is not None:
|
||||
self.model_api_key = api_config.api_key
|
||||
self.provider = api_config.provider.lower()
|
||||
self.is_omni = api_config.is_omni
|
||||
self.capability = api_config.capability
|
||||
self.audio_api_key = audio_api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
end_user_id: uuid.UUID | str,
|
||||
files: Optional[List[FileInput]],
|
||||
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
files: 文件输入列表
|
||||
|
||||
Returns:
|
||||
@@ -319,6 +346,8 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
if isinstance(end_user_id, uuid.UUID):
|
||||
end_user_id = str(end_user_id)
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
@@ -333,19 +362,25 @@ class MultimodalService:
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
strategy = strategy_class(file)
|
||||
if not file.url:
|
||||
file.url = await self.get_file_url(file)
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
except Exception as e:
|
||||
@@ -355,7 +390,8 @@ class MultimodalService:
|
||||
"file_index": idx,
|
||||
"file_type": file.type,
|
||||
"error": str(e)
|
||||
}
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
result.append({
|
||||
@@ -366,6 +402,17 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""写入感知记忆"""
|
||||
if end_user_id and self.api_config:
|
||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -387,43 +434,6 @@ class MultimodalService:
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_encode_image(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
Args:
|
||||
url: 图片 URL
|
||||
|
||||
Returns:
|
||||
tuple: (base64_data, media_type)
|
||||
"""
|
||||
from mimetypes import guess_type
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
media_type = content_type
|
||||
else:
|
||||
# 从 URL 推断
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -436,7 +446,6 @@ class MultimodalService:
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
@@ -471,12 +480,12 @@ class MultimodalService:
|
||||
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
if self.enable_audio_transcription and self.audio_api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
@@ -557,7 +566,7 @@ class MultimodalService:
|
||||
file_content = file.get_content()
|
||||
if not file_content:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file.url)
|
||||
response = await client.get(file.url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
file_content = response.content
|
||||
file.set_content(file_content)
|
||||
|
||||
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
53
api/app/services/prompt/perceptual_summary_system.jinja2
Normal file
@@ -0,0 +1,53 @@
|
||||
{% raw %}You are a professional information extraction system.
|
||||
|
||||
Your task is to analyze the provided document content and generate structured metadata.
|
||||
|
||||
Extract the following fields:
|
||||
|
||||
* **summary**: A concise summary of the document in 2–4 sentences.
|
||||
* **keywords**: 5–10 important keywords or key phrases that best represent the document. This field MUST be a JSON array of strings.
|
||||
* **topic**: The primary topic of the document expressed as a short phrase (3–8 words).
|
||||
* **domain**: The broader knowledge domain or field the document belongs to (e.g., Artificial Intelligence, Computer Science, Finance, Healthcare, Education, Law, etc.).
|
||||
|
||||
STRICT RULES:
|
||||
|
||||
1. Output MUST be valid JSON.
|
||||
2. Do NOT output markdown.
|
||||
3. Do NOT output explanations.
|
||||
4. Do NOT output any text before or after the JSON.
|
||||
5. The JSON MUST contain EXACTLY these four keys:
|
||||
* summary
|
||||
* keywords
|
||||
* topic
|
||||
* domain{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} * scene {% endif %}
|
||||
{% if file_type == 'audio' %} * speaker_count {% endif %}
|
||||
{% if file_type == 'document' %} * section_count
|
||||
* title
|
||||
* first_line
|
||||
{% endif %}
|
||||
{% raw %}
|
||||
6. `keywords` MUST be a JSON array of strings.
|
||||
7. If the document content is insufficient, infer the best possible answer based on context.
|
||||
8. Ensure the JSON is syntactically correct.
|
||||
{% endraw %}
|
||||
9. Output using the language {{ language }}
|
||||
{% raw %}
|
||||
Required JSON format:
|
||||
|
||||
{
|
||||
"summary": "string",
|
||||
"keywords": ["keyword1", "keyword2", "keyword3", "keyword4", "keyword5"],
|
||||
"topic": "string",
|
||||
"domain": "string",
|
||||
{% endraw %}
|
||||
{% if file_type == 'image' or file_type == 'video' %} "scene": ["string", "string"] {% endif %}
|
||||
{% if file_type == 'document' %} "section_count": integer
|
||||
"title": "string",
|
||||
"first_line": "string"
|
||||
{% endif %}
|
||||
{% if file_type == 'audio' %} "speaker_count": integer {% endif %}
|
||||
{% raw %}
|
||||
}
|
||||
|
||||
Now analyze the following document and return the JSON result.{% endraw %}
|
||||
@@ -217,4 +217,55 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active
|
||||
)
|
||||
)
|
||||
|
||||
def get_tenant_language_config(self, tenant_id: uuid.UUID) -> Optional[dict]:
|
||||
"""获取租户语言配置"""
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
return {
|
||||
"default_language": tenant.default_language,
|
||||
"supported_languages": tenant.supported_languages
|
||||
}
|
||||
|
||||
def update_tenant_language_config(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
default_language: str,
|
||||
supported_languages: list
|
||||
) -> Optional[dict]:
|
||||
"""更新租户语言配置"""
|
||||
# 检查租户是否存在
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
# 验证默认语言在支持的语言列表中
|
||||
if default_language not in supported_languages:
|
||||
raise BusinessException(
|
||||
"默认语言必须在支持的语言列表中",
|
||||
code=BizCode.VALIDATION_FAILED
|
||||
)
|
||||
|
||||
try:
|
||||
# 更新语言配置
|
||||
tenant.default_language = default_language
|
||||
tenant.supported_languages = supported_languages
|
||||
self.db.commit()
|
||||
self.db.refresh(tenant)
|
||||
|
||||
business_logger.info(
|
||||
f"更新租户语言配置成功: {tenant.name} (ID: {tenant.id}), "
|
||||
f"默认语言: {default_language}, 支持语言: {supported_languages}"
|
||||
)
|
||||
|
||||
return {
|
||||
"default_language": tenant.default_language,
|
||||
"supported_languages": tenant.supported_languages
|
||||
}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
business_logger.error(f"更新租户语言配置失败: {str(e)}")
|
||||
raise BusinessException(f"更新租户语言配置失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
@@ -1727,6 +1727,150 @@ async def analytics_graph_data(
|
||||
|
||||
# 辅助函数
|
||||
|
||||
async def analytics_community_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取社区图谱数据,包含 Community 节点、ExtractedEntity 节点及其关系。
|
||||
|
||||
Returns:
|
||||
包含 nodes、edges、statistics 的字典,格式与 analytics_graph_data 一致
|
||||
"""
|
||||
try:
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
if not end_user:
|
||||
return {
|
||||
"nodes": [], "edges": [],
|
||||
"statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
|
||||
"message": "用户不存在"
|
||||
}
|
||||
|
||||
# 查询社区节点、实体节点、BELONGS_TO_COMMUNITY 边、实体间关系
|
||||
from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA
|
||||
rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id)
|
||||
|
||||
nodes_map: Dict[str, dict] = {}
|
||||
edges_map: Dict[str, dict] = {}
|
||||
# 记录每个 Community 对应的实体 id 列表
|
||||
community_members: Dict[str, list] = {}
|
||||
|
||||
for row in rows:
|
||||
# Community 节点
|
||||
c_id = row["c_id"]
|
||||
if c_id and c_id not in nodes_map:
|
||||
raw = row["c_props"] or {}
|
||||
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
|
||||
"community_id", "end_user_id", "member_count", "updated_at",
|
||||
"name", "summary", "core_entities",
|
||||
) if k in raw}
|
||||
nodes_map[c_id] = {
|
||||
"id": c_id,
|
||||
"label": "Community",
|
||||
"properties": props,
|
||||
}
|
||||
|
||||
# ExtractedEntity 节点 (e)
|
||||
e_id = row["e_id"]
|
||||
if e_id and e_id not in nodes_map:
|
||||
raw = row["e_props"] or {}
|
||||
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
|
||||
"name", "end_user_id", "description", "created_at", "entity_type",
|
||||
) if k in raw}
|
||||
# 注入所属社区名称(c 是 e 直接归属的社区)
|
||||
c_raw = row["c_props"] or {}
|
||||
props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or ""
|
||||
nodes_map[e_id] = {
|
||||
"id": e_id,
|
||||
"label": "ExtractedEntity",
|
||||
"properties": props,
|
||||
}
|
||||
|
||||
# ExtractedEntity 节点 (e2,可选)
|
||||
e2_id = row.get("e2_id")
|
||||
if e2_id and e2_id not in nodes_map:
|
||||
raw = row["e2_props"] or {}
|
||||
props = {k: _clean_neo4j_value(raw.get(k)) for k in (
|
||||
"name", "end_user_id", "description", "created_at", "entity_type",
|
||||
) if k in raw}
|
||||
# e2 的社区归属在后处理阶段通过 community_members 补充
|
||||
props["community_name"] = ""
|
||||
nodes_map[e2_id] = {
|
||||
"id": e2_id,
|
||||
"label": "ExtractedEntity",
|
||||
"properties": props,
|
||||
}
|
||||
|
||||
# BELONGS_TO_COMMUNITY 边
|
||||
b_id = row["b_id"]
|
||||
if b_id and b_id not in edges_map:
|
||||
edges_map[b_id] = {
|
||||
"id": b_id,
|
||||
"source": e_id,
|
||||
"target": c_id,
|
||||
}
|
||||
# 收集社区成员 id
|
||||
if c_id and e_id:
|
||||
community_members.setdefault(c_id, [])
|
||||
if e_id not in community_members[c_id]:
|
||||
community_members[c_id].append(e_id)
|
||||
|
||||
# EXTRACTED_RELATIONSHIP 边(可选)
|
||||
r_id = row.get("r_id")
|
||||
if r_id and r_id not in edges_map and e2_id:
|
||||
r_props = {k: _clean_neo4j_value(v) for k, v in (row["r_props"] or {}).items()}
|
||||
source = e_id if row.get("r_from_e") else e2_id
|
||||
target = e2_id if row.get("r_from_e") else e_id
|
||||
edges_map[r_id] = {
|
||||
"id": r_id,
|
||||
"source": source,
|
||||
"target": target,
|
||||
}
|
||||
|
||||
nodes = list(nodes_map.values())
|
||||
edges = list(edges_map.values())
|
||||
|
||||
# 为每个 Community 节点注入 member_entity_ids,同时补全 e2 节点的 community_name
|
||||
for c_id, member_ids in community_members.items():
|
||||
c_node = nodes_map.get(c_id)
|
||||
if c_node:
|
||||
c_node["properties"]["member_entity_ids"] = member_ids
|
||||
c_name = c_node["properties"].get("name") or ""
|
||||
# 补全属于该社区但 community_name 为空的实体(即 e2 节点)
|
||||
for eid in member_ids:
|
||||
e_node = nodes_map.get(eid)
|
||||
if e_node and e_node["label"] == "ExtractedEntity":
|
||||
if not e_node["properties"].get("community_name"):
|
||||
e_node["properties"]["community_name"] = c_name
|
||||
|
||||
node_type_counts: Dict[str, int] = {}
|
||||
for n in nodes:
|
||||
node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"statistics": {
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"node_types": node_type_counts,
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"nodes": [], "edges": [],
|
||||
"statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
|
||||
"message": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据节点类型提取需要的属性字段
|
||||
|
||||
@@ -438,24 +438,26 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
|
||||
|
||||
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
|
||||
"""普通用户修改自己的密码"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
|
||||
|
||||
# 检查权限:只能修改自己的密码
|
||||
if current_user.id != user_id:
|
||||
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
|
||||
raise PermissionDeniedException("You can only change your own password")
|
||||
raise PermissionDeniedException(t("auth.password.change_failed"))
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, db_user.hashed_password):
|
||||
business_logger.warning(f"用户旧密码验证失败: {user_id}")
|
||||
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
|
||||
raise BusinessException(t("auth.password.incorrect"), code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 更新密码
|
||||
db_user.hashed_password = get_password_hash(new_password)
|
||||
@@ -471,7 +473,7 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
|
||||
except Exception as e:
|
||||
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
|
||||
@@ -487,6 +489,8 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
Returns:
|
||||
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
|
||||
"""
|
||||
from app.i18n.service import t
|
||||
|
||||
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
|
||||
|
||||
# 检查权限:只有超级管理员可以修改他人密码
|
||||
@@ -496,7 +500,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员可以修改他人密码"
|
||||
error_message=t("auth.password.change_failed")
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
|
||||
@@ -507,12 +511,12 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
|
||||
if not target_user:
|
||||
business_logger.warning(f"目标用户不存在: {target_user_id}")
|
||||
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查租户权限:超管只能修改同租户用户的密码
|
||||
if current_user.tenant_id != target_user.tenant_id:
|
||||
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
|
||||
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 如果没有提供新密码,则生成随机密码
|
||||
actual_password = new_password if new_password else generate_random_password()
|
||||
@@ -532,7 +536,7 @@ async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_pass
|
||||
except Exception as e:
|
||||
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(t("auth.password.change_failed"), code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def generate_random_password(length: int = 12) -> str:
|
||||
@@ -740,3 +744,54 @@ async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: Em
|
||||
#
|
||||
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
# return db_user
|
||||
|
||||
|
||||
def get_user_language_preference(db: Session, user_id: uuid.UUID, current_user: User) -> str:
|
||||
"""获取用户语言偏好"""
|
||||
business_logger.info(f"获取用户语言偏好: user_id={user_id}")
|
||||
|
||||
# 权限检查:只能获取自己的语言偏好
|
||||
if current_user.id != user_id:
|
||||
raise PermissionDeniedException("只能获取自己的语言偏好")
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
language = db_user.preferred_language or "zh"
|
||||
business_logger.info(f"用户语言偏好: {db_user.username}, language={language}")
|
||||
return language
|
||||
|
||||
|
||||
def update_user_language_preference(
|
||||
db: Session,
|
||||
user_id: uuid.UUID,
|
||||
language: str,
|
||||
current_user: User
|
||||
) -> User:
|
||||
"""更新用户语言偏好"""
|
||||
business_logger.info(f"更新用户语言偏好: user_id={user_id}, language={language}")
|
||||
|
||||
# 权限检查:只能修改自己的语言偏好
|
||||
if current_user.id != user_id:
|
||||
raise PermissionDeniedException("只能修改自己的语言偏好")
|
||||
|
||||
# 验证语言代码是否支持
|
||||
from app.core.config import settings
|
||||
if language not in settings.I18N_SUPPORTED_LANGUAGES:
|
||||
raise BusinessException(
|
||||
f"不支持的语言代码: {language}。支持的语言: {', '.join(settings.I18N_SUPPORTED_LANGUAGES)}",
|
||||
code=BizCode.VALIDATION_FAILED
|
||||
)
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 更新语言偏好
|
||||
db_user.preferred_language = language
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
business_logger.info(f"用户语言偏好更新成功: {db_user.username}, language={language}")
|
||||
return db_user
|
||||
|
||||
Reference in New Issue
Block a user