diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 76fc0db5..e9b539df 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -194,6 +194,7 @@ def delete_app( def copy_app( app_id: uuid.UUID, new_name: Optional[str] = None, + payload: app_schema.CopyAppRequest = None, db: Session = Depends(get_db), current_user=Depends(get_current_user), ): @@ -205,6 +206,8 @@ def copy_app( - 不影响原应用 """ workspace_id = current_user.current_workspace_id + # body takes precedence over query param for backward compatibility + new_name = (payload.new_name if payload else None) or new_name logger.info( "用户请求复制应用", extra={ @@ -517,7 +520,7 @@ async def draft_run( # 提前验证和准备(在流式响应开始前完成) from app.services.app_service import AppService from app.services.multi_agent_service import MultiAgentService - from app.models import AgentConfig, ModelConfig + from app.models import AgentConfig, ModelConfig, AppRelease from sqlalchemy import select from app.core.exceptions import BusinessException from app.services.draft_run_service import AgentRunService @@ -537,6 +540,7 @@ async def draft_run( # 先获取 app 的 workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, workspace_id=app.workspace_id, other_id=str(current_user.id), ) @@ -555,18 +559,29 @@ async def draft_run( service._check_agent_config(app_id) # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) + # 共享应用:从最新发布版本读配置快照,而非草稿 + is_shared = app.workspace_id != workspace_id + if is_shared: + if not app.current_release_id: + raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING) + release = db.get(AppRelease, app.current_release_id) + if not release: + raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING) + agent_cfg = service._agent_config_from_release(release) + model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None + else: + stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) + agent_cfg = db.scalars(stmt).first() + if not agent_cfg: + raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - model_config = db.get(ModelConfig, agent_cfg.default_model_config_id) - if not model_config: - from app.core.exceptions import ResourceNotFoundException - raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id)) + # 3. 获取模型配置 + model_config = None + if agent_cfg.default_model_config_id: + model_config = db.get(ModelConfig, agent_cfg.default_model_config_id) + if not model_config: + from app.core.exceptions import ResourceNotFoundException + raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id)) # 流式返回 if payload.stream: @@ -722,7 +737,17 @@ async def draft_run( msg="多 Agent 任务执行成功" ) elif app.type == AppType.WORKFLOW: # 工作流 - config = workflow_service.check_config(app_id) + # 共享应用:从最新发布版本读配置快照,而非草稿 + is_shared = app.workspace_id != workspace_id + if is_shared: + if not app.current_release_id: + raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING) + release = db.get(AppRelease, app.current_release_id) + if not release: + raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING) + config = service._workflow_config_from_release(release) + else: + config = workflow_service.check_config(app_id) # 3. 流式返回 if payload.stream: logger.debug( @@ -869,6 +894,7 @@ async def draft_run_compare( # 先获取 app 的 workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( + app_id=app_id, workspace_id=app.workspace_id, other_id=str(current_user.id), ) diff --git a/api/app/controllers/file_storage_controller.py b/api/app/controllers/file_storage_controller.py index b79035c0..ff284f39 100644 --- a/api/app/controllers/file_storage_controller.py +++ b/api/app/controllers/file_storage_controller.py @@ -15,7 +15,7 @@ import os import uuid from typing import Any -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi.responses import FileResponse, RedirectResponse from sqlalchemy.orm import Session @@ -47,6 +47,19 @@ router = APIRouter( ) +def _match_scheme(request: Request, url: str) -> str: + """ + 将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。 + 解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。 + """ + incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme + if url.startswith("http://") and incoming_scheme == "https": + return "https://" + url[7:] + if url.startswith("https://") and incoming_scheme == "http": + return "http://" + url[8:] + return url + + @router.post("/files", response_model=ApiResponse) async def upload_file( file: UploadFile = File(...), @@ -280,6 +293,7 @@ async def upload_file_with_share_token( @router.get("/files/{file_id}", response_model=Any) async def download_file( + request: Request, file_id: uuid.UUID, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), @@ -327,6 +341,7 @@ async def download_file( else: try: presigned_url = await storage_service.get_file_url(file_key, expires=3600) + presigned_url = _match_scheme(request, presigned_url) api_logger.info(f"Redirecting to presigned URL: file_key={file_key}") return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except FileNotFoundError: @@ -400,6 +415,7 @@ async def delete_file( @router.get("/files/{file_id}/url", response_model=ApiResponse) async def get_file_url( + request: Request, file_id: uuid.UUID, expires: int = None, permanent: bool = False, @@ -463,6 +479,7 @@ async def get_file_url( else: # For remote storage (OSS/S3), get presigned URL url = await storage_service.get_file_url(file_key, expires=expires) + url = _match_scheme(request, url) api_logger.info(f"Generated file URL: file_id={file_id}") return success( @@ -484,6 +501,7 @@ async def get_file_url( @router.get("/public/{file_id}", response_model=Any) async def public_download_file( + request: Request, file_id: uuid.UUID, expires: int = 0, signature: str = "", @@ -555,6 +573,7 @@ async def public_download_file( # For remote storage, redirect to presigned URL try: presigned_url = await storage_service.get_file_url(file_key, expires=3600) + presigned_url = _match_scheme(request, presigned_url) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except Exception as e: api_logger.error(f"Failed to get presigned URL: {e}") @@ -566,6 +585,7 @@ async def public_download_file( @router.get("/permanent/{file_id}", response_model=Any) async def permanent_download_file( + request: Request, file_id: uuid.UUID, db: Session = Depends(get_db), storage_service: FileStorageService = Depends(get_file_storage_service), @@ -625,6 +645,7 @@ async def permanent_download_file( try: # Use a very long expiration (7 days max for most cloud providers) presigned_url = await storage_service.get_file_url(file_key, expires=604800) + presigned_url = _match_scheme(request, presigned_url) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) except Exception as e: api_logger.error(f"Failed to get presigned URL: {e}") diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 19c82790..34572964 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -219,6 +219,7 @@ def list_conversations( app_service = AppService(db) app = app_service._get_app_or_404(share.app_id) new_end_user = end_user_repo.get_or_create_end_user( + app_id=share.app_id, workspace_id=app.workspace_id, other_id=other_id ) @@ -315,6 +316,7 @@ async def chat( app = app_service._get_app_or_404(share.app_id) workspace_id = app.workspace_id new_end_user = end_user_repo.get_or_create_end_user( + app_id=share.app_id, workspace_id=workspace_id, other_id=other_id, original_user_id=user_id @@ -661,6 +663,7 @@ async def config_query( content = { "app_type": release.app.type, "variables": workflow_service.get_start_node_variables(release.config), + "memory": workflow_service.is_memory_enable(release.config), "features": release.config.get("features") } elif release.app.type == AppType.AGENT: diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index d642861e..3b054d2a 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -94,6 +94,7 @@ async def chat( workspace_id = app.workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( + app_id=app.id, workspace_id=workspace_id, other_id=other_id, ) diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 61048061..5563b9d7 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -3,6 +3,8 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode from app.schemas.tool_schema import ( ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest, ToolActiveUpdate @@ -250,8 +252,10 @@ async def sync_mcp_tools( try: result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) if not result.get("success", False): - raise HTTPException(status_code=400, detail=result.get("message", "同步失败")) + raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST) return success(data=result, msg="MCP工具列表同步完成") + except BusinessException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -274,8 +278,10 @@ async def test_tool_connection( # 普通连接测试 result = await service.test_connection(tool_id, current_user.tenant_id) if result["success"] is False: - raise HTTPException(status_code=400, detail=result["message"]) + raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE) return success(data=result, msg="连接测试完成") + except BusinessException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py index 9e683ead..27dea86e 100644 --- a/api/app/core/tools/mcp/base.py +++ b/api/app/core/tools/mcp/base.py @@ -195,6 +195,6 @@ class MCPToolManager: except Exception as e: return { "success": False, - "error": str(e), - "message": "连接失败" + "error": "连接失败", + "message": str(e) } \ No newline at end of file diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index c2885ab0..ddee9adc 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -5,7 +5,7 @@ import re from typing import AsyncGenerator -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from app.core.logging_config import get_logger from app.core.workflow.engine.variable_pool import VariablePool @@ -52,10 +52,11 @@ class OutputContent(BaseModel): ) ) - _SCOPE: str | None = None + _SCOPE: str | None = PrivateAttr(default=None) - def get_scope(self) -> str: - self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0] + def get_scope(self) -> str | None: + matches = SCOPE_PATTERN.findall(self.literal) + self._SCOPE = matches[0] if matches else None return self._SCOPE def depends_on_scope(self, scope: str) -> bool: @@ -68,6 +69,8 @@ class OutputContent(BaseModel): Returns: bool: True if this segment references the given scope. """ + if not self.is_variable: + return False if self._SCOPE: return self._SCOPE == scope return self.get_scope() == scope @@ -152,7 +155,7 @@ class StreamOutputConfig(BaseModel): """ # Case 1: resolve control branch dependency - if scope in self.control_nodes.keys(): + if scope in self.control_nodes: if status is None: raise RuntimeError("[Stream Output] Control node activation status not provided") if status in self.control_nodes[scope]: diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 096f498f..0e9d3c62 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -27,7 +27,6 @@ class ToolNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return { "data": VariableType.STRING, - "error_code": VariableType.STRING, "execution_time": VariableType.NUMBER } @@ -48,10 +47,7 @@ class ToolNode(BaseNode): if not tenant_id: logger.error(f"节点 {self.node_id} 缺少租户ID") - return { - "success": False, - "data": "缺少租户ID" - } + raise ValueError("缺少租户ID") # 渲染工具参数 rendered_parameters = {} @@ -83,13 +79,8 @@ class ToolNode(BaseNode): logger.info(f"节点 {self.node_id} 工具执行成功") return { "data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False), - "error_code": "", "execution_time": result.execution_time } else: logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") - return { - "data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False), - "error_code": result.error_code, - "execution_time": result.execution_time - } + raise ValueError(f"工具执行失败: {result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False)}") diff --git a/api/app/models/workflow_model.py b/api/app/models/workflow_model.py index 4f9ffe68..29fe5369 100644 --- a/api/app/models/workflow_model.py +++ b/api/app/models/workflow_model.py @@ -35,6 +35,7 @@ class WorkflowConfig(Base): # 执行配置 execution_config = Column(JSONB, nullable=False, default=dict) + features = Column(JSONB, nullable=True, default=dict) # 触发器配置(可选) triggers = Column(JSONB, default=list) diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 590655a8..71c93634 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -66,7 +66,8 @@ class EndUserRepository: raise def get_or_create_end_user( - self, + self, + app_id: uuid.UUID, workspace_id: uuid.UUID, other_id: str, original_user_id: Optional[str] = None @@ -74,6 +75,7 @@ class EndUserRepository: """获取或创建终端用户 Args: + app_id: 应用ID workspace_id: 工作空间ID other_id: 第三方ID original_user_id: 原始用户ID (存储到 other_id) @@ -92,10 +94,14 @@ class EndUserRepository: if end_user: db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}") + end_user.app_id=app_id + self.db.commit() + self.db.refresh(end_user) return end_user # 创建新用户 end_user = EndUser( + app_id=app_id, workspace_id=workspace_id, other_id=other_id ) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index b63c73f5..8bfedd5a 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -525,6 +525,13 @@ class AppRelease(BaseModel): return int(dt.timestamp() * 1000) if dt else None +# ---------- App Copy Schema ---------- + +class CopyAppRequest(BaseModel): + """复制应用请求""" + new_name: Optional[str] = Field(None, description="新应用名称,不填则使用原名称-副本") + + # ---------- App Share Schemas ---------- class AppShareCreate(BaseModel): diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index e580833f..d878d97c 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -80,6 +80,7 @@ class WorkflowConfigCreate(BaseModel): variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表") + features: dict = Field(default_factory=dict, description="功能特性配置") class WorkflowConfigUpdate(BaseModel): @@ -87,6 +88,7 @@ class WorkflowConfigUpdate(BaseModel): nodes: list[NodeDefinition] | None = None edges: list[EdgeDefinition] | None = None variables: list[VariableDefinition] | None = None + features: dict | None = None execution_config: ExecutionConfig | None = None triggers: list[TriggerConfig] | None = None @@ -102,6 +104,7 @@ class WorkflowConfig(BaseModel): variables: list[dict[str, Any]] execution_config: dict[str, Any] triggers: list[dict[str, Any]] + features: dict | None is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -114,6 +117,10 @@ class WorkflowConfig(BaseModel): def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("features", when_used="json") + def _serialize_features(self, features: dict | None): + return features or {} + # ==================== 工作流执行 ==================== diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 1b0613e8..94ee606e 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1426,6 +1426,47 @@ class AppService: logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)}) return agent_cfg + def _agent_config_from_release(self, release: "AppRelease") -> "AgentConfig": + """从发布版本快照重建 AgentConfig 对象(不入库,仅用于运行)""" + cfg = release.config or {} + now = release.created_at or datetime.datetime.now() + agent_cfg = AgentConfig( + id=uuid.uuid4(), + app_id=release.app_id, + system_prompt=cfg.get("system_prompt", ""), + default_model_config_id=release.default_model_config_id, + model_parameters=cfg.get("model_parameters"), + knowledge_retrieval=cfg.get("knowledge_retrieval"), + memory=cfg.get("memory", {}), + variables=cfg.get("variables", []), + tools=cfg.get("tools", []), + skills=cfg.get("skills", {}), + features=cfg.get("features", {}), + is_active=True, + created_at=now, + updated_at=now, + ) + return agent_cfg + + def _workflow_config_from_release(self, release: "AppRelease") -> "WorkflowConfig": + """从发布版本快照重建 WorkflowConfig 对象(不入库,仅用于运行)""" + cfg = release.config or {} + now = release.created_at or datetime.datetime.now() + from app.models.workflow_model import WorkflowConfig as WorkflowConfigModel + wf_cfg = WorkflowConfigModel( + id=uuid.uuid4(), + app_id=release.app_id, + nodes=cfg.get("nodes", []), + edges=cfg.get("edges", []), + variables=cfg.get("variables", []), + execution_config=cfg.get("execution_config", {}), + triggers=cfg.get("triggers", []), + is_active=True, + created_at=now, + updated_at=now, + ) + return wf_cfg + def get_agent_config( self, *, @@ -1457,6 +1498,15 @@ class AppService: # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) + # 共享应用:返回最新发布版本的配置快照,而非草稿 + if workspace_id and app.workspace_id != workspace_id: + if not app.current_release_id: + raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING) + release = self.db.get(AppRelease, app.current_release_id) + if not release: + raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING) + return self._agent_config_from_release(release) + stmt = select(AgentConfig).where( AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True) @@ -1555,6 +1605,16 @@ class AppService: # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) + + # 共享应用:返回最新发布版本的配置快照,而非草稿 + if workspace_id and app.workspace_id != workspace_id: + if not app.current_release_id: + raise BusinessException("该应用尚未发布,无法使用", BizCode.CONFIG_MISSING) + release = self.db.get(AppRelease, app.current_release_id) + if not release: + raise BusinessException("发布版本不存在", BizCode.CONFIG_MISSING) + return self._workflow_config_from_release(release) + repo = WorkflowConfigRepository(self.db) config = repo.get_by_app_id(app_id) if config: @@ -1609,6 +1669,7 @@ class AppService: variables=[var.model_dump() for var in data.variables] if data.variables else [], execution_config=data.execution_config.model_dump() if data.execution_config else {}, triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + features=data.features or {}, is_active=True, created_at=now, updated_at=now @@ -1622,6 +1683,7 @@ class AppService: workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else [] workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {} workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else [] + workflow_cfg.features = data.features or {} workflow_cfg.updated_at = now self.db.commit() @@ -1875,7 +1937,8 @@ class AppService: "edges": workflow_cfg.edges, "variables": workflow_cfg.variables, "execution_config": workflow_cfg.execution_config, - "triggers": workflow_cfg.triggers + "triggers": workflow_cfg.triggers, + "features": workflow_cfg.features or {} } is_valid, errors = WorkflowValidator.validate_for_publish(config) @@ -2062,7 +2125,8 @@ class AppService: ) if memory_config_id: - updated_count = self._update_endusers_memory_config(app_id, memory_config_id) + + updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id) logger.info( f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " f"memory_config_id={memory_config_id}, updated_count={updated_count}" diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 2c103d7c..92b13bfc 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -374,7 +374,7 @@ class AgentRunService: files: Optional[List[FileInput]] ) -> None: """校验上传文件是否符合 file_upload 配置""" - if not files: + if not files or not features_config: return fu = features_config.get("file_upload", {}) if not (isinstance(fu, dict) and fu.get("enabled")): diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 53d935fe..8a7c86e2 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -5,12 +5,14 @@ from urllib.parse import urlparse, unquote import json_repair from jinja2 import Template +from sqlalchemy import select from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig +from app.models import FileMetadata from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.models.prompt_optimizer_model import RoleType from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository @@ -245,6 +247,18 @@ class MemoryPerceptualService: filename = os.path.basename(path) filename = unquote(filename) file_ext = os.path.splitext(filename)[1] + try: + file_id = uuid.UUID(filename) + stmt = select(FileMetadata).where( + FileMetadata.id == file_id + ) + file = self.db.execute(stmt).scalar_one_or_none() + + if file: + filename = file.file_name + file_ext = file.file_ext + except ValueError: + business_logger.debug(f"Remote file, file_id={filename}") if not file_ext: if file_type == FileType.AUDIO: file_ext = ".mp3" @@ -262,17 +276,17 @@ class MemoryPerceptualService: } if file_type in [FileType.IMAGE, FileType.VIDEO]: file_modalities = { - "scene": content.get("scene") + "scene": content.get("scene", []) } elif file_type in [FileType.DOCUMENT]: file_modalities = { - "section_count": content.get("section_count"), - "title": content.get("title"), - "first_line": content.get("first_line") + "section_count": content.get("section_count", 0), + "title": content.get("title", ""), + "first_line": content.get("first_line", "") } else: file_modalities = { - "speaker_count": content.get("speaker_count") + "speaker_count": content.get("speaker_count", 0) } self.repository.create_perceptual_memory( end_user_id=uuid.UUID(end_user_id), @@ -280,7 +294,7 @@ class MemoryPerceptualService: file_path=file_url, file_name=filename, file_ext=file_ext, - summary=content.get('summary'), + summary=content.get('summary', ""), meta_data={ "content": file_content, "modalities": file_modalities diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 208f6ec0..1f0e1cc2 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -14,9 +14,13 @@ import uuid from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional +import csv +import json + import PyPDF2 import httpx import magic +import openpyxl from docx import Document from sqlalchemy.orm import Session @@ -39,6 +43,13 @@ DOC_MIME = [ 'application/msword', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' ] +XLSX_MIME = [ + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'application/vnd.ms-excel', + 'application/zip' +] +CSV_MIME = ['text/csv', 'application/csv'] +JSON_MIME = ['application/json'] class MultimodalFormatStrategy(ABC): @@ -48,22 +59,22 @@ class MultimodalFormatStrategy(ABC): self.file = file @abstractmethod - async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]: """格式化图片""" pass @abstractmethod - async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]: """格式化文档""" pass @abstractmethod - async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]: + async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]: """格式化音频""" pass @abstractmethod - async def format_video(self, url: str) -> Dict[str, Any]: + async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]: """格式化视频""" pass @@ -71,16 +82,16 @@ class MultimodalFormatStrategy(ABC): class DashScopeFormatStrategy(MultimodalFormatStrategy): """通义千问策略""" - async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]: """通义千问图片格式:{"type": "image", "image": "url"}""" - return { + return True, { "type": "image", "image": url } - async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]: """通义千问文档格式""" - return { + return True, { "type": "text", "text": f"\n{text}\n" } @@ -91,26 +102,26 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): url: str, content: bytes | None = None, transcription: Optional[str] = None - ) -> Dict[str, Any]: + ) -> tuple[bool, Dict[str, Any]]: """ 通义千问音频格式 - 原生支持: qwen-audio 系列 - 其他模型: 需要转录为文本 """ if transcription: - return { + return True, { "type": "text", "text": f"" } # 通义千问音频格式:{"type": "audio", "audio": "url"} - return { + return True, { "type": "audio", "audio": url } - async def format_video(self, url: str) -> Dict[str, Any]: + async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]: """通义千问视频格式(qwen-vl 系列原生支持)""" - return { + return True, { "type": "video", "video": url } @@ -119,7 +130,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy): class BedrockFormatStrategy(MultimodalFormatStrategy): """Bedrock/Anthropic 策略""" - async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]: """ Bedrock/Anthropic 格式: base64 编码 {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} @@ -142,7 +153,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - return { + return True, { "type": "image", "source": { "type": "base64", @@ -151,13 +162,13 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): } } - async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]: """Bedrock/Anthropic 文档格式(需要 base64 编码)""" # Bedrock 文档需要 base64 编码 text_bytes = text.encode('utf-8') base64_text = base64.b64encode(text_bytes).decode('utf-8') - return { + return True, { "type": "document", "source": { "type": "base64", @@ -171,24 +182,24 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): url: str, content: bytes | None = None, transcription: Optional[str] = None - ) -> Dict[str, Any]: + ) -> tuple[bool, Dict[str, Any]]: """ Bedrock/Anthropic 音频格式 不支持原生音频,必须转录为文本 """ if transcription: - return { + return True, { "type": "text", "text": f"[音频转录]\n{transcription}" } - return { + return False, { "type": "text", "text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]" } - async def format_video(self, url: str) -> Dict[str, Any]: + async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]: """Bedrock/Anthropic 视频格式""" - return { + return False, { "type": "text", "text": f"" } @@ -197,18 +208,18 @@ class BedrockFormatStrategy(MultimodalFormatStrategy): class OpenAIFormatStrategy(MultimodalFormatStrategy): """OpenAI 策略""" - async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]: + async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" - return { + return True, { "type": "image_url", "image_url": { "url": url } } - async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]: """OpenAI 文档格式""" - return { + return True, { "type": "text", "text": f"\n{text}\n" } @@ -219,14 +230,14 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): url: str, content: bytes | None = None, transcription: Optional[str] = None - ) -> Dict[str, Any]: + ) -> tuple[bool, Dict[str, Any]]: """ OpenAI 音频格式 - gpt-4o-audio 系列支持原生音频(需要 base64 编码) - 其他模型使用转录文本 """ if transcription: - return { + return True, { "type": "text", "text": f"" } @@ -255,7 +266,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} file_ext = "wav" if not file_ext else file_ext - return { + return True, { "type": "input_audio", "input_audio": { "data": f"data:;base64,{base64_audio}", @@ -264,14 +275,14 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy): } except Exception as e: logger.error(f"下载音频失败: {e}") - return { + return False, { "type": "text", "text": f"[音频处理失败: {str(e)}]" } - async def format_video(self, url: str) -> Dict[str, Any]: + async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]: """OpenAI 视频格式""" - return { + return True, { "type": "video_url", "video_url": { "url": url @@ -366,21 +377,25 @@ class MultimodalService: file.url = await self.get_file_url(file) try: if file.type == FileType.IMAGE and "vision" in self.capability: - content = await self._process_image(file, strategy) + is_support, content = await self._process_image(file, strategy) result.append(content) - self.write_perceptual_memory(end_user_id, file.type, file.url, content) + if is_support: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.DOCUMENT: - content = await self._process_document(file, strategy) + is_support, content = await self._process_document(file, strategy) result.append(content) - self.write_perceptual_memory(end_user_id, file.type, file.url, content) + if is_support: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.AUDIO and "audio" in self.capability: - content = await self._process_audio(file, strategy) + is_support, content = await self._process_audio(file, strategy) result.append(content) - self.write_perceptual_memory(end_user_id, file.type, file.url, content) + if is_support: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) elif file.type == FileType.VIDEO and "video" in self.capability: - content = await self._process_video(file, strategy) + is_support, content = await self._process_video(file, strategy) result.append(content) - self.write_perceptual_memory(end_user_id, file.type, file.url, content) + if is_support: + self.write_perceptual_memory(end_user_id, file.type, file.url, content) else: logger.warning(f"不支持的文件类型: {file.type}") except Exception as e: @@ -413,7 +428,7 @@ class MultimodalService: if end_user_id and self.api_config: write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message) - async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: + async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理图片文件 @@ -429,12 +444,12 @@ class MultimodalService: return await strategy.format_image(file.url, content=file.get_content()) except Exception as e: logger.error(f"处理图片失败: {e}", exc_info=True) - return { + return False, { "type": "text", "text": f"[图片处理失败: {str(e)}]" } - async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: + async def _process_document(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理文档文件(PDF、Word 等) @@ -446,7 +461,7 @@ class MultimodalService: Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: - return { + return True, { "type": "text", "text": f"\n{await self._extract_document_text(file)}\n" } @@ -464,7 +479,7 @@ class MultimodalService: # 使用策略格式化文档 return await strategy.format_document(file_name, text) - async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]: + async def _process_audio(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理音频文件 @@ -492,12 +507,12 @@ class MultimodalService: return await strategy.format_audio(file.file_type, file.url, file.get_content(), transcription) except Exception as e: logger.error(f"处理音频失败: {e}", exc_info=True) - return { + return False, { "type": "text", "text": f"[音频处理失败: {str(e)}]" } - async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]: + async def _process_video(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]: """ 处理视频文件 @@ -513,7 +528,7 @@ class MultimodalService: return await strategy.format_video(file.url) except Exception as e: logger.error(f"处理视频失败: {e}", exc_info=True) - return { + return False, { "type": "text", "text": f"[视频处理失败: {str(e)}]" } @@ -577,6 +592,12 @@ class MultimodalService: return await self._extract_pdf_text(file_content) elif file_mime_type in DOC_MIME: return await self._extract_word_text(file_content) + elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")): + return await self._extract_xlsx_text(file_content) + elif file_mime_type in CSV_MIME: + return await self._extract_csv_text(file_content) + elif file_mime_type in JSON_MIME: + return await self._extract_json_text(file_content) else: return f"[Unsupported file type: {file_mime_type}]" except Exception as e: @@ -602,7 +623,6 @@ class MultimodalService: async def _extract_word_text(file_content: bytes) -> str: """提取 Word 文档文本""" try: - # 使用 BytesIO 读取 Word 文档 word_file = io.BytesIO(file_content) doc = Document(word_file) text_parts = [paragraph.text for paragraph in doc.paragraphs] @@ -611,6 +631,42 @@ class MultimodalService: logger.error(f"提取 Word 文本失败: {e}") return f"[Word 提取失败: {str(e)}]" + @staticmethod + async def _extract_xlsx_text(file_content: bytes) -> str: + """提取 Excel 文本""" + try: + wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True) + parts = [] + for sheet in wb.worksheets: + parts.append(f"[Sheet: {sheet.title}]") + for row in sheet.iter_rows(values_only=True): + parts.append('\t'.join('' if v is None else str(v) for v in row)) + return '\n'.join(parts) + except Exception as e: + logger.error(f"提取 Excel 文本失败: {e}") + return f"[Excel 提取失败: {str(e)}]" + + @staticmethod + async def _extract_csv_text(file_content: bytes) -> str: + """提取 CSV 文本""" + try: + text = file_content.decode('utf-8-sig') + reader = csv.reader(io.StringIO(text)) + return '\n'.join('\t'.join(row) for row in reader) + except Exception as e: + logger.error(f"提取 CSV 文本失败: {e}") + return f"[CSV 提取失败: {str(e)}]" + + @staticmethod + async def _extract_json_text(file_content: bytes) -> str: + """提取 JSON 文本""" + try: + data = json.loads(file_content.decode('utf-8')) + return json.dumps(data, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"提取 JSON 文本失败: {e}") + return f"[JSON 提取失败: {str(e)}]" + def get_multimodal_service(db: Session) -> MultimodalService: """获取多模态服务实例(依赖注入)""" diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 4e7268d3..9f421976 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -570,6 +570,9 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) + feature_configs = config.features or {} + self._validate_file_upload(feature_configs, payload.files) + input_data = { "message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id, @@ -737,6 +740,8 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) + feature_configs = config.features or {} + self._validate_file_upload(feature_configs, payload.files) input_data = { "message": payload.message, "variables": payload.variables, @@ -845,7 +850,10 @@ class WorkflowService: yield event except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + logger.error( + f"Workflow streaming execution failed: execution_id={execution.execution_id}, error={e}", + exc_info=True + ) self.update_execution_status( execution.execution_id, "failed", @@ -868,6 +876,80 @@ class WorkflowService: return node.get("config", {}).get("variables", []) raise BusinessException("workflow config error - start node not found") + @staticmethod + def is_memory_enable(config: dict) -> bool: + nodes = config.get("nodes", []) + for node in nodes: + if node.get("type") in [NodeType.MEMORY_READ, NodeType.MEMORY_WRITE]: + return True + return False + + @staticmethod + def _validate_file_upload( + features_config: dict[str, Any], + files: Optional[list[FileInput]] + ) -> None: + """校验上传文件是否符合 file_upload 配置""" + if not files: + return + fu = features_config.get("file_upload") + if fu is None: + return + if not (isinstance(fu, dict) and fu.get("enabled")): + raise BusinessException( + "The application does not have file upload functionality enabled", + BizCode.BAD_REQUEST + ) + max_count = fu.get("max_file_count", 5) + if len(files) > max_count: + raise BusinessException( + f"File count exceeds limit (maximum {max_count} files)", + BizCode.BAD_REQUEST + ) + + # 校验传输方式 + allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"]) + for f in files: + if f.transfer_method.value not in allowed_methods: + raise BusinessException( + f"Unsupport file transfer method:{f.transfer_method.value}," + f"allowed method:{', '.join(allowed_methods)}", + BizCode.BAD_REQUEST + ) + + # 各类型对应的开关和大小限制配置键 + type_cfg = { + "image": ("image_enabled", "image_max_size_mb", 20, "image"), + "audio": ("audio_enabled", "audio_max_size_mb", 50, "audio"), + "document": ("document_enabled", "document_max_size_mb", 100, "document"), + "video": ("video_enabled", "video_max_size_mb", 500, "video"), + } + + for f in files: + ftype = str(f.type) # 如 "image", "audio", "document", "video" + cfg = type_cfg.get(ftype) + if cfg is None: + continue + enabled_key, size_key, default_max_mb, label = cfg + + # 校验类型开关 + if not fu.get(enabled_key): + raise BusinessException( + f"The application has not enabled {label} file upload", + BizCode.BAD_REQUEST + ) + + # 校验文件大小(仅当内容已加载时) + content = f.get_content() + if content is not None: + max_mb = fu.get(size_key, default_max_mb) + size_mb = len(content) / (1024 * 1024) + if size_mb > max_mb: + raise BusinessException( + f"{label} File size exceeds the limit (maximum {max_mb} MB, current {size_mb:.1f} MB)", + BizCode.BAD_REQUEST + ) + # ==================== 依赖注入函数 ==================== diff --git a/api/migrations/versions/f017efe4831c_202603181652.py b/api/migrations/versions/f017efe4831c_202603181652.py new file mode 100644 index 00000000..833d29c0 --- /dev/null +++ b/api/migrations/versions/f017efe4831c_202603181652.py @@ -0,0 +1,30 @@ +"""202603181652 + +Revision ID: f017efe4831c +Revises: 818c6c535e14 +Create Date: 2026-03-18 16:52:21.639695 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'f017efe4831c' +down_revision: Union[str, None] = '818c6c535e14' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('workflow_configs', sa.Column('features', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('workflow_configs', 'features') + # ### end Alembic commands ### diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index b3b87130..639a9109 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -2,10 +2,12 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:11:51 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-16 18:06:00 + * @Last Modified time: 2026-03-17 18:39:09 */ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' +import { App } from 'antd' +import { useTranslation } from 'react-i18next'; import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import { request } from '@/utils/request' @@ -20,6 +22,7 @@ interface AudioRecorderProps { /** Additional config passed to the upload request */ requestConfig?: Record; disabled?: boolean; + maxSize?: number; } const AudioRecorder: FC = ({ @@ -27,8 +30,11 @@ const AudioRecorder: FC = ({ className = '', action = fileUploadUrlWithoutApiPrefix, requestConfig = {}, - disabled = false + disabled = false, + maxSize, }) => { + const { message } = App.useApp() + const { t } = useTranslation(); // Whether the recorder is currently capturing audio const [isRecording, setIsRecording] = useState(false) // Holds the RecordRTC instance across renders @@ -57,6 +63,12 @@ const AudioRecorder: FC = ({ recorderRef.current.stopRecording(() => { const blob = recorderRef.current!.getBlob() const url = recorderRef.current!.toURL() + + if (maxSize && blob.size > maxSize * 1024 * 1024) { + message.error(t('common.fileSizeTip', { size: maxSize })); + return + } + const formData = new FormData() formData.append('file', blob, `recording_${Date.now()}.webm`) request diff --git a/web/src/components/Chat/ChatToolbar.tsx b/web/src/components/Chat/ChatToolbar.tsx index 1d368c30..883ac98a 100644 --- a/web/src/components/Chat/ChatToolbar.tsx +++ b/web/src/components/Chat/ChatToolbar.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-03-17 14:22:25 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-17 14:22:25 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-18 15:55:13 */ // Toolbar component for chat input area, supporting file upload, audio recording, and variable configuration import { useRef, forwardRef, useImperativeHandle, type ReactNode, useEffect } from 'react' @@ -120,7 +120,10 @@ const ChatToolbar = forwardRef(({ // Build dropdown menu items based on allowed transfer methods const fileMenus: MenuProps['items'] = [] - if (file_upload?.allowed_transfer_methods?.includes('remote_url')) { + const enabledTypes = ['image', 'document', 'video', 'audio'].filter( + type => file_upload?.[`${type}_enabled` as keyof FeaturesConfigForm['file_upload']] + ) + if (file_upload?.allowed_transfer_methods?.includes('remote_url') && enabledTypes.length > 0) { fileMenus.push({ key: 'url', label: t('memoryConversation.addRemoteFile'), @@ -133,9 +136,6 @@ const ChatToolbar = forwardRef(({ } }) } - const enabledTypes = ['image', 'document', 'video', 'audio'].filter( - type => file_upload?.[`${type}_enabled` as keyof FeaturesConfigForm['file_upload']] - ) if (file_upload?.allowed_transfer_methods?.includes('local_file') && enabledTypes.length > 0) { fileMenus.push({ key: 'upload', @@ -151,13 +151,11 @@ const ChatToolbar = forwardRef(({ }) } - console.log('queryValues', queryValues) - return (
-