From ea391dc44ed232a0febcb71d8a0e3c1e3392d0ff Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 16 Mar 2026 18:00:09 +0800 Subject: [PATCH] feat(app): 1. Add new functional features to the agent; 2. Enhance the voice output; 3. Modify the end_user binding; 4. Delete and modify the tools. --- api/app/controllers/app_controller.py | 40 +- .../controllers/public_share_controller.py | 37 +- .../controllers/service/app_api_controller.py | 3 +- api/app/controllers/tool_controller.py | 29 +- api/app/core/tools/mcp/client.py | 2 +- api/app/db.py | 2 +- api/app/models/agent_app_config_model.py | 1 + api/app/models/end_user_model.py | 8 +- api/app/models/tool_model.py | 5 +- api/app/models/workspace_model.py | 1 + api/app/repositories/end_user_repository.py | 121 ++++-- api/app/repositories/tool_repository.py | 71 ++-- api/app/schemas/app_schema.py | 118 ++++++ api/app/schemas/end_user_schema.py | 2 +- api/app/schemas/tool_schema.py | 6 + api/app/services/agent_config_converter.py | 3 + api/app/services/app_chat_service.py | 93 ++++- api/app/services/app_service.py | 29 +- api/app/services/draft_run_service.py | 386 ++++++++++++++++-- api/app/services/memory_dashboard_service.py | 4 +- api/app/services/tool_service.py | 47 ++- api/app/tasks.py | 8 +- 22 files changed, 832 insertions(+), 184 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 059bec6b..76fc0db5 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -254,6 +254,27 @@ def get_agent_config( return success(data=app_schema.AgentConfig.model_validate(cfg)) +@router.get("/{app_id}/opening", summary="获取应用开场白配置") +@cur_workspace_access_guard() +def get_opening( + app_id: uuid.UUID, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """返回开场白文本和预设问题,供前端对话界面初始化时展示""" + workspace_id = current_user.current_workspace_id + cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id) + features = cfg.features or {} + if hasattr(features, "model_dump"): + features = features.model_dump() + opening = features.get("opening_statement", {}) + return success(data=app_schema.OpeningResponse( + enabled=opening.get("enabled", False), + statement=opening.get("statement"), + suggested_questions=opening.get("suggested_questions", []), + )) + + @router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") @cur_workspace_access_guard() def publish_app( @@ -513,11 +534,11 @@ async def draft_run( service._validate_app_accessible(app, workspace_id) if payload.user_id is None: + # 先获取 app 的 workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( - app_id=app_id, + workspace_id=app.workspace_id, other_id=str(current_user.id), - original_user_id=str(current_user.id) # Save original user_id to other_id ) payload.user_id = str(new_end_user.id) @@ -845,11 +866,11 @@ async def draft_run_compare( service._validate_app_accessible(app, workspace_id) if payload.user_id is None: + # 先获取 app 的 workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( - app_id=app_id, + workspace_id=app.workspace_id, other_id=str(current_user.id), - original_user_id=str(current_user.id) # Save original user_id to other_id ) payload.user_id = str(new_end_user.id) @@ -898,7 +919,12 @@ async def draft_run_compare( "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id }) - + # 从 features 中读取功能开关(与 draft_run 保持一致) + features_config: dict = agent_cfg.features or {} + if hasattr(features_config, 'model_dump'): + features_config = features_config.model_dump() + web_search_feature = features_config.get("web_search", {}) + web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False) # 流式返回 if payload.stream: @@ -915,7 +941,7 @@ async def draft_run_compare( variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, - web_search=True, + web_search=web_search, memory=True, parallel=payload.parallel, timeout=payload.timeout or 60, @@ -946,7 +972,7 @@ async def draft_run_compare( variables=payload.variables, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id, - web_search=True, + web_search=web_search, memory=True, parallel=payload.parallel, timeout=payload.timeout or 60, diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 3c634ae0..19c82790 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -22,6 +22,7 @@ from app.schemas import release_share_schema, conversation_schema from app.schemas.response_schema import PageData, PageMeta from app.services import workspace_service from app.services.app_chat_service import AppChatService, get_app_chat_service +from app.services.app_service import AppService from app.services.auth_service import create_access_token from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService @@ -215,8 +216,10 @@ def list_conversations( service = SharedChatService(db) share, release = service.get_release_by_share_token(share_data.share_token, password) end_user_repo = EndUserRepository(db) + app_service = AppService(db) + app = app_service._get_app_or_404(share.app_id) new_end_user = end_user_repo.get_or_create_end_user( - app_id=share.app_id, + workspace_id=app.workspace_id, other_id=other_id ) logger.debug(new_end_user.id) @@ -308,25 +311,28 @@ async def chat( # Store end_user_id in database with original user_id end_user_repo = EndUserRepository(db) + app_service = AppService(db) + app = app_service._get_app_or_404(share.app_id) + workspace_id = app.workspace_id new_end_user = end_user_repo.get_or_create_end_user( - app_id=share.app_id, + workspace_id=workspace_id, other_id=other_id, - original_user_id=user_id # Save original user_id to other_id + original_user_id=user_id ) end_user_id = str(new_end_user.id) - appid = share.app_id + # appid = share.app_id """获取存储类型和工作空间的ID""" # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用) - app = db.query(App).filter( - App.id == appid, - App.is_active.is_(True) - ).first() - if not app: - raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) + # app = db.query(App).filter( + # App.id == appid, + # App.is_active.is_(True) + # ).first() + # if not app: + # raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) - workspace_id = app.workspace_id + # workspace_id = app.workspace_id # 直接从 workspace 获取 storage_type(公开分享场景无需权限检查) storage_type = workspace_service.get_workspace_storage_type_without_auth( @@ -654,17 +660,20 @@ async def config_query( workflow_service = WorkflowService(db) content = { "app_type": release.app.type, - "variables": workflow_service.get_start_node_variables(release.config) + "variables": workflow_service.get_start_node_variables(release.config), + "features": release.config.get("features") } elif release.app.type == AppType.AGENT: content = { "app_type": release.app.type, - "variables": release.config.get("variables") + "variables": release.config.get("variables"), + "features": release.config.get("features") } elif release.app.type == AppType.MULTI_AGENT: content = { "app_type": release.app.type, - "variables": [] + "variables": [], + "features": release.config.get("features") } else: return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 64143f57..d642861e 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -94,9 +94,8 @@ async def chat( workspace_id = app.workspace_id end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( - app_id=app.id, + workspace_id=workspace_id, other_id=other_id, - original_user_id=other_id # Save original user_id to other_id ) end_user_id = str(new_end_user.id) web_search = True diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 10ca83af..61048061 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -4,7 +4,8 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session from app.schemas.tool_schema import ( - ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest + ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, + CustomToolTestRequest, ToolActiveUpdate ) from app.core.response_utils import success @@ -156,7 +157,7 @@ async def delete_tool( current_user: User = Depends(get_current_user), service: ToolService = Depends(get_tool_service) ): - """删除工具""" + """删除工具(逻辑删除,is_active=False)""" try: success_flag = service.delete_tool(tool_id, current_user.tenant_id) if not success_flag: @@ -168,6 +169,30 @@ async def delete_tool( raise HTTPException(status_code=500, detail=str(e)) +@router.patch("/{tool_id}/active", response_model=ApiResponse) +async def set_tool_active( + tool_id: str, + request: ToolActiveUpdate, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """设置工具可用状态(启用/禁用) + + - is_active=true: 启用工具 + - is_active=false: 禁用工具(等同于删除,但可恢复) + """ + try: + success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active) + if not success_flag: + raise HTTPException(status_code=404, detail="工具不存在") + action = "启用" if request.is_active else "禁用" + return success(msg=f"工具已{action}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/execution/execute", response_model=ApiResponse) async def execute_tool( request: ToolExecuteRequest, diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index f19902a2..6df6df51 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -23,7 +23,7 @@ class SimpleMCPClient: def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): self.server_url = server_url self.connection_config = connection_config or {} - self.timeout = self.connection_config.get("timeout", 30) + self.timeout = self.connection_config.get("timeout", 10) # 确定连接类型 self.is_websocket = server_url.startswith(("ws://", "wss://")) diff --git a/api/app/db.py b/api/app/db.py index cdaa6dbd..80ab2756 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -16,7 +16,7 @@ engine = create_engine( pool_recycle=settings.DB_POOL_RECYCLE, pool_timeout=settings.DB_POOL_TIMEOUT, connect_args={ - "options": "-c timezone=Asia/Shanghai -c statement_timeout=60000" + "options": "-c timezone=UTC -c statement_timeout=60000" }, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index cc2e0686..3ece049e 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -31,6 +31,7 @@ class AgentConfig(Base): variables = Column(JSON, default=list, nullable=True, comment="变量配置") tools = Column(JSON, default=list, nullable=True, comment="工具配置") skills = Column(JSON, default=dict, nullable=True, comment="技能配置") + features = Column(JSON, default=dict, nullable=True, comment="功能特性配置") # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index 28a44f1f..60600fcf 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -12,7 +12,8 @@ class EndUser(Base): __tablename__ = "end_users" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) - app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False) + app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=True) + workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False) # end_user_id = Column(String, nullable=False, index=True) other_id = Column(String, nullable=True) # Store original user_id other_name = Column(String, default="", nullable=False) @@ -61,4 +62,7 @@ class EndUser(Base): app = relationship( "App", back_populates="end_users" - ) \ No newline at end of file + ) + + # 与 WorkSpace 的反向关系 + workspace = relationship("Workspace", back_populates="end_users") \ No newline at end of file diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py index 98448bc5..e8d9c528 100644 --- a/api/app/models/tool_model.py +++ b/api/app/models/tool_model.py @@ -110,7 +110,10 @@ class ToolConfig(Base): # 元数据 version = Column(String(50), default="1.0.0") tags = Column(JSON, default=list) # 标签列表 - + + # 逻辑删除标志 + is_active = Column(Boolean, default=True, server_default='true', nullable=False, index=True, comment="是否可用,False表示已删除") + # 时间戳 created_at = Column(DateTime, default=datetime.now, nullable=False) updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) diff --git a/api/app/models/workspace_model.py b/api/app/models/workspace_model.py index 4d42ed32..2db5e3df 100644 --- a/api/app/models/workspace_model.py +++ b/api/app/models/workspace_model.py @@ -38,6 +38,7 @@ class Workspace(Base): members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys memory_increments = relationship("MemoryIncrement", back_populates="workspace") + end_users = relationship("EndUser", back_populates="workspace", cascade="all, delete-orphan") class WorkspaceMember(Base): __tablename__ = "workspace_members" diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 61faf6d4..590655a8 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -32,6 +32,21 @@ class EndUserRepository: db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}") raise + def get_end_users_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]: + """获取指定 workspace 下的所有 end_user""" + try: + end_users = ( + self.db.query(EndUser) + .filter(EndUser.workspace_id == workspace_id) + .all() + ) + db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户") + return end_users + except Exception as e: + self.db.rollback() + db_logger.error(f"查询工作空间 {workspace_id} 下终端用户时出错: {str(e)}") + raise + def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据 end_user_id 查询宿主""" try: @@ -52,14 +67,14 @@ class EndUserRepository: def get_or_create_end_user( self, - app_id: uuid.UUID, + workspace_id: uuid.UUID, other_id: str, original_user_id: Optional[str] = None ) -> EndUser: """获取或创建终端用户 Args: - app_id: 应用ID + workspace_id: 工作空间ID other_id: 第三方ID original_user_id: 原始用户ID (存储到 other_id) """ @@ -68,26 +83,27 @@ class EndUserRepository: end_user = ( self.db.query(EndUser) .filter( - EndUser.app_id == app_id, + EndUser.workspace_id == workspace_id, EndUser.other_id == other_id ) + .order_by(EndUser.created_at.asc()) .first() ) if end_user: - db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}") + db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}") return end_user # 创建新用户 end_user = EndUser( - app_id=app_id, + workspace_id=workspace_id, other_id=other_id ) self.db.add(end_user) self.db.commit() self.db.refresh(end_user) - db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}") + db_logger.info(f"创建新终端用户: (other_id: {other_id}) for workspace {workspace_id}") return end_user except Exception as e: @@ -314,8 +330,7 @@ class EndUserRepository: try: end_users = ( self.db.query(EndUser) - .join(App, EndUser.app_id == App.id) - .filter(App.workspace_id == workspace_id) + .filter(EndUser.workspace_id == workspace_id) .all() ) db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户") @@ -402,45 +417,79 @@ class EndUserRepository: db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}") raise - def batch_update_memory_config_id( - self, - app_id: uuid.UUID, - memory_config_id: uuid.UUID + # def batch_update_memory_config_id( + # self, + # app_id: uuid.UUID, + # memory_config_id: uuid.UUID + # ) -> int: + # """批量更新应用下所有终端用户的 memory_config_id + # + # Args: + # app_id: 应用ID + # memory_config_id: 新的记忆配置ID + # + # Returns: + # int: 更新的行数 + # """ + # try: + # from sqlalchemy import update + # + # stmt = ( + # update(EndUser) + # .where(EndUser.app_id == app_id) + # .values(memory_config_id=memory_config_id) + # ) + # + # result = self.db.execute(stmt) + # self.db.commit() + # + # updated_count = result.rowcount + # + # db_logger.info( + # f"批量更新终端用户记忆配置: app_id={app_id}, " + # f"memory_config_id={memory_config_id}, updated_count={updated_count}" + # ) + # + # return updated_count + # + # except Exception as e: + # self.db.rollback() + # db_logger.error( + # f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + # f"memory_config_id={memory_config_id}, error={str(e)}" + # ) + # raise + + def batch_update_memory_config_id_by_workspace( + self, + workspace_id: uuid.UUID, + memory_config_id: uuid.UUID ) -> int: - """批量更新应用下所有终端用户的 memory_config_id - - Args: - app_id: 应用ID - memory_config_id: 新的记忆配置ID - - Returns: - int: 更新的行数 - """ + """批量更新工作空间下所有终端用户的 memory_config_id""" try: from sqlalchemy import update stmt = ( update(EndUser) - .where(EndUser.app_id == app_id) + .where(EndUser.workspace_id == workspace_id) .values(memory_config_id=memory_config_id) ) - + result = self.db.execute(stmt) self.db.commit() - + updated_count = result.rowcount - + db_logger.info( - f"批量更新终端用户记忆配置: app_id={app_id}, " + f"批量更新终端用户记忆配置: workspace_id={workspace_id}, " f"memory_config_id={memory_config_id}, updated_count={updated_count}" ) - + return updated_count - except Exception as e: self.db.rollback() db_logger.error( - f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + f"批量更新终端用户记忆配置时出错: workspace_id={workspace_id}, " f"memory_config_id={memory_config_id}, error={str(e)}" ) raise @@ -492,7 +541,7 @@ class EndUserRepository: """ try: from sqlalchemy import update - + stmt = ( update(EndUser) .where(EndUser.memory_config_id == memory_config_id) @@ -519,10 +568,16 @@ class EndUserRepository: ) raise -def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: - """根据应用ID查询宿主(返回 EndUser ORM 列表)""" +# def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: +# """根据应用ID查询宿主(返回 EndUser ORM 列表)""" +# repo = EndUserRepository(db) +# end_users = repo.get_end_users_by_app_id(app_id) +# return end_users + +def get_end_users_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]: + """根据工作空间ID查询终端用户(返回 EndUser ORM 列表)""" repo = EndUserRepository(db) - end_users = repo.get_end_users_by_app_id(app_id) + end_users = repo.get_end_users_by_workspace(workspace_id) return end_users def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 257910c3..1a9b0b87 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -27,7 +27,7 @@ class ToolRepository: from app.models.app_model import App from app.models.workflow_model import WorkflowConfig from app.models.workspace_model import Workspace - + result = db.query(Workspace.tenant_id).join( App, App.workspace_id == Workspace.id ).join( @@ -35,7 +35,7 @@ class ToolRepository: ).filter( WorkflowConfig.id == workflow_id ).first() - + return result[0] if result else None @staticmethod @@ -67,18 +67,19 @@ class ToolRepository: @staticmethod def find_by_tenant( - db: Session, - tenant_id: uuid.UUID, - name: Optional[str] = None, - tool_type: Optional[ToolType] = None, - status: Optional[ToolStatus] = None, - is_enabled: Optional[bool] = None + db: Session, + tenant_id: uuid.UUID, + name: Optional[str] = None, + tool_type: Optional[ToolType] = None, + status: Optional[ToolStatus] = None, + is_enabled: Optional[bool] = None ) -> List[ToolConfig]: - """根据租户查找工具""" + """根据租户查找工具(只返回未删除的)""" query = db.query(ToolConfig).filter( - ToolConfig.tenant_id == tenant_id + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True) ) - + if name: query = query.filter(ToolConfig.name.ilike(f"%{name}%")) if tool_type: @@ -91,8 +92,17 @@ class ToolRepository: return query.all() @staticmethod - def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: - """根据ID和租户查找工具""" + def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: + """根据ID和租户查找工具(只返回未删除的)""" + return db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True) + ).first() + + @staticmethod + def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: + """根据ID和租户查找工具(返回所有工具包括删除的)""" return db.query(ToolConfig).filter( ToolConfig.id == tool_id, ToolConfig.tenant_id == tenant_id @@ -100,29 +110,26 @@ class ToolRepository: @staticmethod def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: - """统计租户工具数量""" + """统计租户工具数量(只统计未删除的)""" return db.query(ToolConfig).filter( - ToolConfig.tenant_id == tenant_id + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True) ).count() @staticmethod def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: """获取状态统计""" - return db.query( - ToolConfig.status, - func.count(ToolConfig.id).label('count') - ).filter( - ToolConfig.tenant_id == tenant_id + return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True) ).group_by(ToolConfig.status).all() @staticmethod def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: """获取类型统计""" - return db.query( - ToolConfig.tool_type, - func.count(ToolConfig.id).label('count') - ).filter( - ToolConfig.tenant_id == tenant_id + return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True) ).group_by(ToolConfig.tool_type).all() @staticmethod @@ -130,6 +137,7 @@ class ToolRepository: """统计租户启用的工具数量""" return db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, + ToolConfig.is_active.is_(True), ToolConfig.is_enabled == True ).count() @@ -138,7 +146,8 @@ class ToolRepository: """检查租户是否已有内置工具""" return db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, - ToolConfig.tool_type == ToolType.BUILTIN.value + ToolConfig.tool_type == ToolType.BUILTIN.value, + ToolConfig.is_active.is_(True) ).count() > 0 @@ -194,10 +203,10 @@ class ToolExecutionRepository: @staticmethod def find_by_tool_and_tenant( - db: Session, - tool_id: uuid.UUID, - tenant_id: uuid.UUID, - limit: int = 100 + db: Session, + tool_id: uuid.UUID, + tenant_id: uuid.UUID, + limit: int = 100 ) -> List[ToolExecution]: """根据工具和租户查找执行记录""" return db.query(ToolExecution).join( @@ -205,4 +214,4 @@ class ToolExecutionRepository: ).filter( ToolConfig.id == tool_id, ToolConfig.tenant_id == tenant_id - ).order_by(ToolExecution.started_at.desc()).limit(limit).all() \ No newline at end of file + ).order_by(ToolExecution.started_at.desc()).limit(limit).all() diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 8e7e4bd2..9e7bc876 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -125,6 +125,85 @@ class SkillConfig(BaseModel): all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能") +# ---------- App Features ---------- + +class FileUploadConfig(BaseModel): + """文件上传配置""" + enabled: bool = Field(default=False) + # 允许的传输方式:local_file / remote_url,默认两种都允许 + allowed_transfer_methods: List[str] = Field( + default=["local_file", "remote_url"], + description="允许的传输方式" + ) + # 图片文件:PNG/JPG/JPEG/GIF/WEBP,最大 20MB + image_enabled: bool = Field(default=False) + image_max_size_mb: int = Field(default=20) + image_allowed_extensions: List[str] = Field( + default=["png", "jpg", "jpeg", "gif", "webp"] + ) + # 语音文件:MP3/WAV/M4A/OGG/FLAC,最大 50MB + audio_enabled: bool = Field(default=False) + audio_max_size_mb: int = Field(default=50) + audio_allowed_extensions: List[str] = Field( + default=["mp3", "wav", "m4a", "ogg", "flac"] + ) + # 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB + document_enabled: bool = Field(default=False) + document_max_size_mb: int = Field(default=100) + document_allowed_extensions: List[str] = Field( + default=["pdf", "docx", "xlsx", "txt", "csv", "json"] + ) + # 视频文件:MP4/MOV/AVI/WebM,最大 500MB + video_enabled: bool = Field(default=False) + video_max_size_mb: int = Field(default=500) + video_allowed_extensions: List[str] = Field( + default=["mp4", "mov", "avi", "webm"] + ) + # 最大文件数量 + max_file_count: int = Field(default=5, ge=1, le=20) + + +class OpeningStatementConfig(BaseModel): + """对话开场白配置""" + enabled: bool = Field(default=False) + statement: Optional[str] = Field(default=None, description="开场白内容") + suggested_questions: List[str] = Field(default_factory=list, description="预设问题列表") + + +class SuggestedQuestionsConfig(BaseModel): + """下一步问题建议配置""" + enabled: bool = Field(default=False) + + +class TextToSpeechConfig(BaseModel): + """文字转语音配置""" + enabled: bool = Field(default=False) + voice: Optional[str] = Field(default=None, description="语音音色") + language: Optional[str] = Field(default=None, description="语言") + autoplay: bool = Field(default=False, description="是否自动播放") + + +class CitationConfig(BaseModel): + """引用和归属配置""" + enabled: bool = Field(default=False) + + +class WebSearchConfig(BaseModel): + """联网搜索配置""" + enabled: bool = Field(default=False) + search_engine: Optional[str] = Field(default=None, description="搜索引擎") + + +class AppFeatures(BaseModel): + """应用功能特性配置""" + file_upload: FileUploadConfig = Field(default_factory=FileUploadConfig) + opening_statement: OpeningStatementConfig = Field(default_factory=OpeningStatementConfig) + suggested_questions_after_answer: SuggestedQuestionsConfig = Field(default_factory=SuggestedQuestionsConfig) + text_to_speech: TextToSpeechConfig = Field(default_factory=TextToSpeechConfig) + citation: CitationConfig = Field(default_factory=CitationConfig) + web_search: WebSearchConfig = Field(default_factory=WebSearchConfig) + + class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") @@ -201,6 +280,9 @@ class AgentConfigCreate(BaseModel): # 技能配置 skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") + # 功能特性 + features: Optional[AppFeatures] = Field(default=None, description="功能特性配置") + class AppCreate(BaseModel): name: str @@ -258,6 +340,9 @@ class AgentConfigUpdate(BaseModel): # 技能配置 skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") + # 功能特性 + features: Optional[AppFeatures] = Field(default=None, description="功能特性配置") + # ---------- Output Schemas ---------- @@ -323,6 +408,8 @@ class AgentConfig(BaseModel): skills: Optional[SkillConfig] = {} + features: Optional[AppFeatures] = None + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -359,6 +446,14 @@ class AgentConfig(BaseModel): return {} return v + @field_validator("features", mode="before") + @classmethod + def validate_features(cls, v): + """处理 None 值,返回默认 AppFeatures""" + if v is None: + return AppFeatures() + return v + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -500,12 +595,35 @@ class DraftRunRequest(BaseModel): files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)") +class SuggestedQuestion(BaseModel): + """建议问题""" + content: str + + +class CitationSource(BaseModel): + """引用来源""" + title: str + content: str + score: Optional[float] = None + kb_id: Optional[str] = None + + class DraftRunResponse(BaseModel): """试运行响应(非流式)""" message: str = Field(..., description="AI 回复消息") conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况") elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)") + suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题") + citations: List[CitationSource] = Field(default_factory=list, description="引用来源") + audio_url: Optional[str] = Field(default=None, description="TTS 语音URL") + + +class OpeningResponse(BaseModel): + """应用开场白响应""" + enabled: bool + statement: Optional[str] = None + suggested_questions: List[str] = Field(default_factory=list) class DraftRunStreamChunk(BaseModel): diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 6f7498a0..bbb6fd5c 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -8,7 +8,7 @@ class EndUser(BaseModel): model_config = ConfigDict(from_attributes=True) id: uuid.UUID = Field(description="终端用户ID") - app_id: uuid.UUID = Field(description="应用ID") + app_id: Optional[uuid.UUID] = Field(description="应用ID", default=None) # end_user_id: str = Field(description="终端用户ID") other_id: Optional[str] = Field(description="第三方ID", default=None) other_name: Optional[str] = Field(description="其他名称", default="") diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py index 2ba86c2c..79e01688 100644 --- a/api/app/schemas/tool_schema.py +++ b/api/app/schemas/tool_schema.py @@ -90,6 +90,7 @@ class ToolInfo(BaseModel): parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置") status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态") + is_active: bool = Field(True, description="是否可用(False 表示已删除)") tags: List[str] = Field(default_factory=list, description="工具标签") tenant_id: Optional[str] = Field(None, description="租户ID") created_at: datetime = Field(..., description="创建时间") @@ -212,6 +213,11 @@ class ToolUpdateRequest(BaseModel): tags: Optional[List[str]] = None +class ToolActiveUpdate(BaseModel): + """工具可用状态更新""" + is_active: bool = Field(..., description="True=启用, False=禁用(逻辑删除)") + + class ToolExecuteRequest(BaseModel): """执行工具请求""" tool_id: str diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index fbc75f4c..f86b8f19 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -51,6 +51,9 @@ class AgentConfigConverter: if hasattr(config, "skills") and config.skills: result["skills"] = config.skills.model_dump() + + if hasattr(config, "features") and config.features: + result["features"] = config.features.model_dump() return result diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 9b2b2a77..cd9d3e81 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -49,12 +49,23 @@ class AppChatService: storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None # 新增:多模态文件 + files: Optional[List[FileInput]] = None ) -> Dict[str, Any]: """聊天(非流式)""" start_time = time.time() config_id = None + # 应用 features 配置 + features_config: dict = config.features or {} + if hasattr(features_config, 'model_dump'): + features_config = features_config.model_dump() + web_search_feature = features_config.get("web_search", {}) + if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")): + web_search = False + + # 校验文件上传 + self.agent_service._validate_file_upload(features_config, files) + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID @@ -107,17 +118,14 @@ class AppChatService: ) # 加载历史消息 - history = [] - memory_config = {"enabled": True, 'max_history': 10} - if memory_config.get("enabled"): - messages = self.conversation_service.get_messages( - conversation_id=conversation_id, - limit=memory_config.get("max_history", 10) - ) - history = [ - {"role": msg.role, "content": msg.content} - for msg in messages - ] + messages = self.conversation_service.get_messages( + conversation_id=conversation_id, + limit=10 + ) + history = [ + {"role": msg.role, "content": msg.content} + for msg in messages + ] # 处理多模态文件 processed_files = None @@ -166,6 +174,23 @@ class AppChatService: elapsed_time = time.time() - start_time + # suggested_questions + suggested_questions = [] + sq_config = features_config.get("suggested_questions_after_answer", {}) + if isinstance(sq_config, dict) and sq_config.get("enabled"): + suggested_questions = await self.agent_service._generate_suggested_questions( + features_config, result["content"], + {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base}, {} + ) + + audio_url = await self.agent_service._generate_tts( + features_config, result["content"], + {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, "provider": api_key_obj.provider}, + tenant_id=tenant_id, workspace_id=workspace_id + ) + return { "conversation_id": conversation_id, "message_id": str(message_id), @@ -175,7 +200,10 @@ class AppChatService: "completion_tokens": 0, "total_tokens": 0 }), - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "suggested_questions": suggested_questions, + "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), + "audio_url": audio_url, } async def agnet_chat_stream( @@ -190,7 +218,7 @@ class AppChatService: storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None # 新增:多模态文件 + files: Optional[List[FileInput]] = None ) -> AsyncGenerator[str, None]: """聊天(流式)""" @@ -198,10 +226,19 @@ class AppChatService: start_time = time.time() config_id = None message_id = uuid.uuid4() - yield f"event: start\ndata: {json.dumps({ - 'conversation_id': str(conversation_id), - "message_id": str(message_id) - }, ensure_ascii=False)}\n\n" + + # 应用 features 配置 + features_config: dict = config.features or {} + if hasattr(features_config, 'model_dump'): + features_config = features_config.model_dump() + web_search_feature = features_config.get("web_search", {}) + if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")): + web_search = False + + # 校验文件上传 + self.agent_service._validate_file_upload(features_config, files) + + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n" variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID @@ -327,8 +364,22 @@ class AppChatService: ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} + # 发送结束事件(包含 suggested_questions、tts、citations) + end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} + sq_config = features_config.get("suggested_questions_after_answer", {}) + if isinstance(sq_config, dict) and sq_config.get("enabled"): + end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions( + features_config, full_content, + {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base}, {} + ) + end_data["audio_url"] = await self.agent_service._generate_tts( + features_config, full_content, + {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key, + "api_base": api_key_obj.api_base, "provider": api_key_obj.provider}, + tenant_id=tenant_id, workspace_id=workspace_id + ) + end_data["citations"] = self.agent_service._filter_citations(features_config, []) yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" logger.info( @@ -442,7 +493,7 @@ class AppChatService: try: message_id = uuid.uuid4() # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n" + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n" full_content = "" total_tokens = 0 diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index c06d79a9..a551b24c 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -109,7 +109,7 @@ class AppService: return share is not None - def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: + def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: """验证应用是否可访问(包括共享应用,用于只读操作) Args: @@ -360,6 +360,7 @@ class AppService: variables=storage_data.get("variables", []), tools=storage_data.get("tools", []), skills=storage_data.get("skills", {}), + features=storage_data.get("features", {}), is_active=True, created_at=now, updated_at=now, @@ -1073,6 +1074,7 @@ class AppService: # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", []) agent_cfg.skills = storage_data.get("skills", {}) + agent_cfg.features = storage_data.get("features", {}) agent_cfg.updated_at = now @@ -1173,6 +1175,7 @@ class AppService: variables=[], tools=[], skills=[], + features={}, is_active=True, created_at=now, updated_at=now, @@ -1389,15 +1392,15 @@ class AppService: return config.config_id - def _update_endusers_memory_config( + def _update_endusers_memory_config_by_workspace( self, - app_id: uuid.UUID, + workspace_id: uuid.UUID, memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id Args: - app_id: 应用ID + workspace_id: 工作空间ID memory_config_id: 新的记忆配置ID Returns: @@ -1406,8 +1409,8 @@ class AppService: from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(self.db) - updated_count = repo.batch_update_memory_config_id( - app_id=app_id, + updated_count = repo.batch_update_memory_config_id_by_workspace( + workspace_id=workspace_id, memory_config_id=memory_config_id ) @@ -1578,11 +1581,15 @@ class AppService: ) if memory_config_id: - updated_count = self._update_endusers_memory_config(app_id, memory_config_id) - logger.info( - f"发布时更新终端用户记忆配置: app_id={app_id}, " - f"memory_config_id={memory_config_id}, updated_count={updated_count}" - ) + app = self.db.query(App).filter(App.id == app_id).first() + if app: + updated_count = self._update_endusers_memory_config_by_workspace( + app.workspace_id, memory_config_id + ) + logger.info( + f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) # 更新当前发布版本指针 app.current_release_id = release.id diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 619a5f10..1ba47bba 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import Session from app.celery_app import celery_app from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent +from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger @@ -262,9 +263,12 @@ class AgentRunService: def load_tools_config(self, tools_config, web_search, tenant_id) -> list: """加载工具配置""" - if not tools_config: - return [] tools = [] + if web_search: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + if not tools_config: + return tools tool_service = ToolService(self.db) if tools_config and isinstance(tools_config, list): @@ -273,24 +277,15 @@ class AgentRunService: # 根据工具名称查找工具实例 tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) - elif tools_config and isinstance(tools_config, dict): - web_search_choice = tools_config.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search and web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) return tools def load_skill_config( @@ -373,6 +368,86 @@ class AgentRunService: ) return tools, bool(memory_config.get("enabled")) + @staticmethod + def _validate_file_upload( + features_config: Dict[str, Any], + files: Optional[List[FileInput]] + ) -> None: + """校验上传文件是否符合 file_upload 配置""" + if not files: + return + fu = features_config.get("file_upload", {}) + if not (isinstance(fu, dict) and fu.get("enabled")): + raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST) + max_count = fu.get("max_file_count", 5) + if len(files) > max_count: + raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST) + + # 校验传输方式 + allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"]) + for f in files: + if str(f.transfer_method) not in allowed_methods: + raise BusinessException( + f"不支持的文件传输方式:{f.transfer_method},允许的方式:{', '.join(allowed_methods)}", + BizCode.BAD_REQUEST + ) + + # 各类型对应的开关和大小限制配置键 + type_cfg = { + "image": ("image_enabled", "image_max_size_mb", 20, "图片"), + "audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"), + "document": ("document_enabled", "document_max_size_mb", 100, "文档"), + "video": ("video_enabled", "video_max_size_mb", 500, "视频"), + } + + for f in files: + ftype = str(f.type) # 如 "image", "audio", "document", "video" + cfg = type_cfg.get(ftype) + if cfg is None: + continue + enabled_key, size_key, default_max_mb, label = cfg + + # 校验类型开关 + if not fu.get(enabled_key): + raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST) + + # 校验文件大小(仅当内容已加载时) + content = f.get_content() + if content is not None: + max_mb = fu.get(size_key, default_max_mb) + size_mb = len(content) / (1024 * 1024) + if size_mb > max_mb: + raise BusinessException( + f"{label}文件大小超过限制(最大 {max_mb}MB,当前 {size_mb:.1f}MB)", + BizCode.BAD_REQUEST + ) + + @staticmethod + def _inject_opening_statement( + features_config: Dict[str, Any], + system_prompt: str, + is_new_conversation: bool + ) -> str: + """首轮对话时将开场白注入 system_prompt""" + if not is_new_conversation: + return system_prompt + opening = features_config.get("opening_statement", {}) + if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): + return system_prompt + statement = opening["statement"] + return f"{system_prompt}\n\n[对话开场白]\n{statement}" + + @staticmethod + def _filter_citations( + features_config: Dict[str, Any], + citations: List[Any] + ) -> List[Any]: + """根据 citation 开关决定是否返回引用来源""" + citation_cfg = features_config.get("citation", {}) + if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): + return citations + return [] + async def run( self, *, @@ -415,6 +490,15 @@ class AgentRunService: skills_config: dict | None = agent_config.skills knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval memory_config: dict | None = agent_config.memory + features_config: dict = agent_config.features or {} + + # 从 features 中读取功能开关(优先级高于参数默认值) + web_search_feature = features_config.get("web_search", {}) + if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"): + web_search = False + + # file_upload 校验 + self._validate_file_upload(features_config, files) try: # 1. 获取 API Key 配置 @@ -449,6 +533,10 @@ class AgentRunService: # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" + # opening_statement:首轮对话注入开场白 + is_new_conversation = not conversation_id + system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation) + # 4. 准备工具列表 tools = [] @@ -491,12 +579,10 @@ class AgentRunService: ) # 6. 加载历史消息 - history = [] - if memory_config and memory_config.get("enabled"): - history = await self._load_conversation_history( - conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) - ) + history = await self._load_conversation_history( + conversation_id=conversation_id, + max_history=10 + ) # 6. 处理多模态文件 processed_files = None @@ -551,7 +637,7 @@ class AgentRunService: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 保存会话消息 - if not sub_agent and memory_config and memory_config.get("enabled"): + if not sub_agent: await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -575,7 +661,15 @@ class AgentRunService: "completion_tokens": 0, "total_tokens": 0 }), - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "suggested_questions": await self._generate_suggested_questions( + features_config, result["content"], api_key_config, effective_params + ) if not sub_agent else [], + "citations": self._filter_citations(features_config, result.get("citations", [])), + "audio_url": await self._generate_tts( + features_config, result["content"], api_key_config, + tenant_id=tenant_id, workspace_id=workspace_id + ) if not sub_agent else None, } logger.info( @@ -630,6 +724,15 @@ class AgentRunService: skills_config: dict | None = agent_config.skills knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval memory_config: dict | None = agent_config.memory + features_config: dict = agent_config.features or {} + + # 从 features 中读取功能开关 + web_search_feature = features_config.get("web_search", {}) + if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")): + web_search = False + + # file_upload 校验 + self._validate_file_upload(features_config, files) start_time = time.time() @@ -659,6 +762,10 @@ class AgentRunService: # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" + # opening_statement:首轮对话注入开场白 + is_new_conversation = not conversation_id + system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation) + # 4. 准备工具列表 tools = [] @@ -703,12 +810,10 @@ class AgentRunService: ) # 6. 加载历史消息 - history = [] - if memory_config and memory_config.get("enabled"): - history = await self._load_conversation_history( - conversation_id=conversation_id, - max_history=memory_config.get("max_history", 10) - ) + history = await self._load_conversation_history( + conversation_id=conversation_id, + max_history=memory_config.get("max_history", 10) + ) # 6. 处理多模态文件 processed_files = None @@ -774,7 +879,7 @@ class AgentRunService: }) # 10. 保存会话消息 - if not sub_agent and memory_config and memory_config.get("enabled"): + if not sub_agent: await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -786,12 +891,22 @@ class AgentRunService: } ) - # 11. 发送结束事件 - yield self._format_sse_event("end", { + # 11. 发送结束事件(包含 suggested_questions 和 tts) + end_data: Dict[str, Any] = { "conversation_id": conversation_id, "elapsed_time": elapsed_time, "message_length": len(full_content) - }) + } + if not sub_agent: + end_data["suggested_questions"] = await self._generate_suggested_questions( + features_config, full_content, api_key_config, effective_params + ) + end_data["audio_url"] = await self._generate_tts( + features_config, full_content, api_key_config, + tenant_id=tenant_id, workspace_id=workspace_id + ) + end_data["citations"] = self._filter_citations(features_config, []) + yield self._format_sse_event("end", end_data) logger.info( "流式试运行完成", @@ -1137,6 +1252,165 @@ class AgentRunService: logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)}) return {} + async def _generate_suggested_questions( + self, + features_config: Dict[str, Any], + assistant_message: str, + api_key_config: Dict[str, Any], + effective_params: Dict[str, Any] + ) -> List[str]: + """根据 suggested_questions_after_answer 配置生成下一步建议问题""" + sq_config = features_config.get("suggested_questions_after_answer", {}) + if not isinstance(sq_config, dict) or not sq_config.get("enabled"): + return [] + try: + from langchain_openai import ChatOpenAI + from langchain_core.messages import HumanMessage, SystemMessage + llm = ChatOpenAI( + model=api_key_config["model_name"], + api_key=api_key_config["api_key"], + base_url=api_key_config.get("api_base"), + temperature=0.5, + max_tokens=200, + ) + prompt = ( + f"根据以下AI回复,生成3个用户可能继续追问的简短问题,每行一个,不加序号:\n\n{assistant_message}" + ) + resp = await llm.ainvoke([HumanMessage(content=prompt)]) + lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()] + return lines[:3] + except Exception as e: + logger.warning(f"生成建议问题失败: {e}") + return [] + + async def _generate_tts( + self, + features_config: Dict[str, Any], + text: str, + api_key_config: Dict[str, Any], + tenant_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None, + ) -> Optional[str]: + """根据 text_to_speech 配置生成语音,上传到存储并返回 URL""" + tts_config = features_config.get("text_to_speech", {}) + if not isinstance(tts_config, dict) or not tts_config.get("enabled"): + return None + if not text or not text.strip(): + return None + + try: + from app.services.file_storage_service import FileStorageService + + provider = api_key_config.get("provider", "openai") + api_key = api_key_config.get("api_key") + api_base = api_key_config.get("api_base") + voice = tts_config.get("voice") + + if provider == "dashscope": + audio_bytes, file_ext, content_type = await self._tts_dashscope( + api_key=api_key, + text=text, + voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀 + tts_config=tts_config, + ) + else: + # OpenAI 兼容接口(openai / xinference / gpustack 等) + audio_bytes, file_ext, content_type = await self._tts_openai( + api_key=api_key, + api_base=api_base, + text=text, + voice=voice or "alloy", + ) + + storage_service = FileStorageService() + file_id = uuid.uuid4() + file_key = await storage_service.upload_file( + tenant_id=tenant_id, + workspace_id=workspace_id, + file_id=file_id, + file_ext=file_ext, + content=audio_bytes, + content_type=content_type, + ) + + # 保存文件元数据到数据库 + from app.models.file_metadata_model import FileMetadata + db_file = FileMetadata( + id=file_id, + tenant_id=tenant_id, + workspace_id=workspace_id, + file_key=file_key, + file_name=f"tts_{file_id}{file_ext}", + file_ext=file_ext, + file_size=len(audio_bytes), + content_type=content_type, + status="completed", + ) + self.db.add(db_file) + self.db.commit() + + server_url = settings.FILE_LOCAL_SERVER_URL + audio_url = f"{server_url}/storage/permanent/{file_id}" + logger.debug(f"TTS 生成成功,provider={provider}, file_key={file_key}") + return audio_url + + except Exception as e: + logger.warning(f"TTS 生成失败: {e}") + return None + + @staticmethod + async def _tts_openai( + api_key: str, + api_base: Optional[str], + text: str, + voice: str, + ) -> tuple: + """OpenAI 兼容 TTS,返回 (audio_bytes, file_ext, content_type)""" + from openai import AsyncOpenAI + client = AsyncOpenAI(api_key=api_key, base_url=api_base) + response = await client.audio.speech.create( + model="tts-1", + voice=voice, + input=text[:4096], + ) + return response.content, ".mp3", "audio/mpeg" + + @staticmethod + async def _tts_dashscope( + api_key: str, + text: str, + voice: str, + tts_config: Dict[str, Any], + ) -> tuple: + """DashScope CosyVoice TTS,返回 (audio_bytes, file_ext, content_type)""" + import dashscope + from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat + + model = tts_config.get("model") or "cosyvoice-v2" + is_v2 = model.endswith("-v2") + + # cosyvoice-v2 音色名带 _v2 后缀,v1 不带 + # 如果用户传入的 voice 不匹配当前模型版本,自动修正 + if is_v2 and not voice.endswith("_v2"): + voice = voice + "_v2" + elif not is_v2 and voice.endswith("_v2"): + voice = voice[:-3] # 去掉 _v2 + + def _sync_call() -> bytes: + dashscope.api_key = api_key + synthesizer = SpeechSynthesizer( + model=model, + voice=voice, + format=AudioFormat.MP3_22050HZ_MONO_256KBPS, + ) + audio = synthesizer.call(text[:4096]) + if audio is None: + raise RuntimeError("DashScope TTS 返回空音频") + return audio + + audio_bytes = await asyncio.to_thread(_sync_call) + return audio_bytes, ".mp3", "audio/mpeg" + def _replace_variables( self, text: str, @@ -1221,6 +1495,12 @@ class AgentRunService: } ) + # 提前校验文件上传(与 run() 内部保持一致) + features_config: dict = agent_config.features or {} + if hasattr(features_config, 'model_dump'): + features_config = features_config.model_dump() + # self._validate_file_upload(features_config, files) + async def run_single_model(model_info): """运行单个模型""" try: @@ -1271,6 +1551,9 @@ class AgentRunService: if elapsed > 0 and usage.get("completion_tokens") else None ), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]), + "audio_url": result.get("audio_url"), + "citations": result.get("citations", []), + "suggested_questions": result.get("suggested_questions", []), "error": None } @@ -1343,7 +1626,12 @@ class AgentRunService: ) return { - "results": results, + "results": [{ + **r, + "audio_url": r.get("audio_url"), + "citations": r.get("citations", []), + "suggested_questions": r.get("suggested_questions", []), + } for r in results], "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), "successful_count": len(successful), "failed_count": len(failed), @@ -1434,6 +1722,12 @@ class AgentRunService: extra={"model_count": len(models), "parallel": parallel} ) + # 提前校验文件上传 + # features_config: dict = agent_config.features or {} + # if hasattr(features_config, 'model_dump'): + # features_config = features_config.model_dump() + # self._validate_file_upload(features_config, files) + # 发送开始事件 yield self._format_sse_event("compare_start", { "conversation_id": conversation_id, @@ -1465,6 +1759,9 @@ class AgentRunService: start_time = time.time() full_content = "" returned_conversation_id = model_conversation_id + audio_url = None + citations = [] + suggested_questions = [] # 临时修改参数 original_params = agent_config.model_parameters @@ -1518,6 +1815,12 @@ class AgentRunService: "content": chunk })) + # 从 end 事件中提取 features 输出字段 + if event_type == "end" and event_data: + audio_url = event_data.get("audio_url") + citations = event_data.get("citations", []) + suggested_questions = event_data.get("suggested_questions", []) + if event_type == "error" and event_data: await event_queue.put(self._format_sse_event("model_error", { "model_index": idx, @@ -1543,6 +1846,9 @@ class AgentRunService: "parameters_used": model_info["parameters"], "message": full_content, "elapsed_time": elapsed, + "audio_url": audio_url, + "citations": citations, + "suggested_questions": suggested_questions, "error": None } @@ -1554,6 +1860,9 @@ class AgentRunService: "conversation_id": returned_conversation_id, "elapsed_time": elapsed, "message_length": len(full_content), + "audio_url": audio_url, + "citations": citations, + "suggested_questions": suggested_questions, "timestamp": time.time() })) @@ -1685,8 +1994,11 @@ class AgentRunService: "model_name": r["model_name"], "label": r["label"], "conversation_id": r.get("conversation_id"), - "message": r.get("message"), # 包含完整消息 + "message": r.get("message"), "elapsed_time": r.get("elapsed_time", 0), + "audio_url": r.get("audio_url"), + "citations": r.get("citations", []), + "suggested_questions": r.get("suggested_questions", []), "error": r.get("error") }) diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index be656acb..d0078088 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -68,14 +68,14 @@ def get_workspace_end_users( return [] # 提取所有 app_id - app_ids = [app.id for app in apps_orm] + # app_ids = [app.id for app in apps_orm] # 批量查询所有 end_users(一次查询而非循环查询) # 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 from app.models.end_user_model import EndUser as EndUserModel from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( - EndUserModel.app_id.in_(app_ids) + EndUserModel.workspace_id == workspace_id ).order_by( nullslast(desc(EndUserModel.created_at)), desc(EndUserModel.id) diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 23def7f8..089f0ec5 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -78,7 +78,7 @@ class ToolService: def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]: """获取工具详情""" - config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) + config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id) return self._config_to_info(config) if config else None def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None): @@ -237,7 +237,7 @@ class ToolService: return False def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool: - """删除工具""" + """删除工具(逻辑删除)""" config = self._get_tool_config(tool_id, tenant_id) if not config: return False @@ -246,14 +246,7 @@ class ToolService: raise ValueError("内置工具不允许删除") try: - # 删除关联表记录 - if config.tool_type == ToolType.CUSTOM.value: - self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete() - elif config.tool_type == ToolType.MCP.value: - self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete() - - # 删除主表记录(ToolExecution会通过cascade自动删除) - self.db.delete(config) + config.is_active = False self._clear_tool_cache(tool_id) self.db.commit() return True @@ -262,6 +255,27 @@ class ToolService: logger.error(f"删除工具失败: {tool_id}, {e}") return False + def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool: + """设置工具可用状态(启用/禁用)""" + # 直接查询,包含 is_active=False 的记录 + config = self.db.query(ToolConfig).filter( + ToolConfig.id == uuid.UUID(tool_id), + ToolConfig.tenant_id == tenant_id + ).first() + if not config: + return False + if config.tool_type == ToolType.BUILTIN.value: + raise ValueError("内置工具不允许修改可用状态") + try: + config.is_active = is_active + self._clear_tool_cache(tool_id) + self.db.commit() + return True + except Exception as e: + self.db.rollback() + logger.error(f"设置工具状态失败: {tool_id}, {e}") + return False + async def execute_tool( self, tool_id: str, @@ -378,7 +392,7 @@ class ToolService: Returns: 方法列表或None """ - config = self._get_tool_config(tool_id, tenant_id) + config = self._get_tool_config_all(tool_id, tenant_id) if not config: return None @@ -857,16 +871,20 @@ class ToolService: } def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]: - """获取工具配置""" + """获取工具配置(仅返回 is_active=True)""" return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) + def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]: + """获取工具配置(返回所有)""" + return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id) + def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: - """获取工具实例""" + """获取工具实例(仅返回 is_active=True 的工具)""" if tool_id in self._tool_cache: return self._tool_cache[tool_id] config = self._get_tool_config(tool_id, tenant_id) - if not config: + if not config or not config.is_active: return None try: @@ -980,6 +998,7 @@ class ToolService: tags=config.tags or [], tenant_id=str(config.tenant_id) if config.tenant_id else None, config_data=config_data, + is_active=config.is_active, created_at=config.created_at ) diff --git a/api/app/tasks.py b/api/app/tasks.py index cae3719b..f5258330 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1292,9 +1292,9 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } # 2. 查询所有app下的end_user_id(去重) - app_ids = [app.id for app in apps] + # app_ids = [app.id for app in apps] end_users = db.query(EndUser.id).filter( - EndUser.app_id.in_(app_ids) + EndUser.workspace_id == workspace_id ).distinct().all() # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 @@ -1433,9 +1433,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: continue # 2. 查询所有app下的end_user_id(去重) - app_ids = [app.id for app in apps] + # app_ids = [app.id for app in apps] end_users = db.query(EndUser.id).filter( - EndUser.app_id.in_(app_ids) + EndUser.workspace_id == workspace_id ).distinct().all() # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加