feat(app):

1. Add new functional features to the agent;
2. Enhance the voice output;
3. Modify the end_user binding;
4. Delete and modify the tools.
This commit is contained in:
Timebomb2018
2026-03-16 18:00:09 +08:00
parent b62c40dba3
commit ea391dc44e
22 changed files with 832 additions and 184 deletions

View File

@@ -254,6 +254,27 @@ def get_agent_config(
return success(data=app_schema.AgentConfig.model_validate(cfg)) return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
@cur_workspace_access_guard()
def get_opening(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
opening = features.get("opening_statement", {})
return success(data=app_schema.OpeningResponse(
enabled=opening.get("enabled", False),
statement=opening.get("statement"),
suggested_questions=opening.get("suggested_questions", []),
))
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)") @router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard() @cur_workspace_access_guard()
def publish_app( def publish_app(
@@ -513,11 +534,11 @@ async def draft_run(
service._validate_app_accessible(app, workspace_id) service._validate_app_accessible(app, workspace_id)
if payload.user_id is None: if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id, workspace_id=app.workspace_id,
other_id=str(current_user.id), other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
) )
payload.user_id = str(new_end_user.id) payload.user_id = str(new_end_user.id)
@@ -845,11 +866,11 @@ async def draft_run_compare(
service._validate_app_accessible(app, workspace_id) service._validate_app_accessible(app, workspace_id)
if payload.user_id is None: if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id, workspace_id=app.workspace_id,
other_id=str(current_user.id), other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
) )
payload.user_id = str(new_end_user.id) payload.user_id = str(new_end_user.id)
@@ -898,7 +919,12 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
}) })
# 从 features 中读取功能开关(与 draft_run 保持一致)
features_config: dict = agent_cfg.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
@@ -915,7 +941,7 @@ async def draft_run_compare(
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=True, web_search=web_search,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60, timeout=payload.timeout or 60,
@@ -946,7 +972,7 @@ async def draft_run_compare(
variables=payload.variables, variables=payload.variables,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
web_search=True, web_search=web_search,
memory=True, memory=True,
parallel=payload.parallel, parallel=payload.parallel,
timeout=payload.timeout or 60, timeout=payload.timeout or 60,

View File

@@ -22,6 +22,7 @@ from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import AppService
from app.services.auth_service import create_access_token from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
@@ -215,8 +216,10 @@ def list_conversations(
service = SharedChatService(db) service = SharedChatService(db)
share, release = service.get_release_by_share_token(share_data.share_token, password) share, release = service.get_release_by_share_token(share_data.share_token, password)
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, workspace_id=app.workspace_id,
other_id=other_id other_id=other_id
) )
logger.debug(new_end_user.id) logger.debug(new_end_user.id)
@@ -308,25 +311,28 @@ async def chat(
# Store end_user_id in database with original user_id # Store end_user_id in database with original user_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, workspace_id=workspace_id,
other_id=other_id, other_id=other_id,
original_user_id=user_id # Save original user_id to other_id original_user_id=user_id
) )
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
appid = share.app_id # appid = share.app_id
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app仅查询未删除的应用 # 直接通过 SQLAlchemy 查询 app仅查询未删除的应用
app = db.query(App).filter( # app = db.query(App).filter(
App.id == appid, # App.id == appid,
App.is_active.is_(True) # App.is_active.is_(True)
).first() # ).first()
if not app: # if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) # raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
workspace_id = app.workspace_id # workspace_id = app.workspace_id
# 直接从 workspace 获取 storage_type公开分享场景无需权限检查 # 直接从 workspace 获取 storage_type公开分享场景无需权限检查
storage_type = workspace_service.get_workspace_storage_type_without_auth( storage_type = workspace_service.get_workspace_storage_type_without_auth(
@@ -654,17 +660,20 @@ async def config_query(
workflow_service = WorkflowService(db) workflow_service = WorkflowService(db)
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config) "variables": workflow_service.get_start_node_variables(release.config),
"features": release.config.get("features")
} }
elif release.app.type == AppType.AGENT: elif release.app.type == AppType.AGENT:
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables") "variables": release.config.get("variables"),
"features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": [] "variables": [],
"features": release.config.get("features")
} }
else: else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED) return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)

View File

@@ -94,9 +94,8 @@ async def chat(
workspace_id = app.workspace_id workspace_id = app.workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id, workspace_id=workspace_id,
other_id=other_id, other_id=other_id,
original_user_id=other_id # Save original user_id to other_id
) )
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
web_search = True web_search = True

View File

@@ -4,7 +4,8 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.schemas.tool_schema import ( from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
CustomToolTestRequest, ToolActiveUpdate
) )
from app.core.response_utils import success from app.core.response_utils import success
@@ -156,7 +157,7 @@ async def delete_tool(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service) service: ToolService = Depends(get_tool_service)
): ):
"""删除工具""" """删除工具逻辑删除is_active=False"""
try: try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id) success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag: if not success_flag:
@@ -168,6 +169,30 @@ async def delete_tool(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.patch("/{tool_id}/active", response_model=ApiResponse)
async def set_tool_active(
tool_id: str,
request: ToolActiveUpdate,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""设置工具可用状态(启用/禁用)
- is_active=true: 启用工具
- is_active=false: 禁用工具(等同于删除,但可恢复)
"""
try:
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
action = "启用" if request.is_active else "禁用"
return success(msg=f"工具已{action}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse) @router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool( async def execute_tool(
request: ToolExecuteRequest, request: ToolExecuteRequest,

View File

@@ -23,7 +23,7 @@ class SimpleMCPClient:
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
self.server_url = server_url self.server_url = server_url
self.connection_config = connection_config or {} self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30) self.timeout = self.connection_config.get("timeout", 10)
# 确定连接类型 # 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://")) self.is_websocket = server_url.startswith(("ws://", "wss://"))

View File

@@ -16,7 +16,7 @@ engine = create_engine(
pool_recycle=settings.DB_POOL_RECYCLE, pool_recycle=settings.DB_POOL_RECYCLE,
pool_timeout=settings.DB_POOL_TIMEOUT, pool_timeout=settings.DB_POOL_TIMEOUT,
connect_args={ connect_args={
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000" "options": "-c timezone=UTC -c statement_timeout=60000"
}, },
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -31,6 +31,7 @@ class AgentConfig(Base):
variables = Column(JSON, default=list, nullable=True, comment="变量配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=list, nullable=True, comment="工具配置") tools = Column(JSON, default=list, nullable=True, comment="工具配置")
skills = Column(JSON, default=dict, nullable=True, comment="技能配置") skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
features = Column(JSON, default=dict, nullable=True, comment="功能特性配置")
# 多 Agent 相关字段 # 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")

View File

@@ -12,7 +12,8 @@ class EndUser(Base):
__tablename__ = "end_users" __tablename__ = "end_users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False) app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=True)
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False)
# end_user_id = Column(String, nullable=False, index=True) # end_user_id = Column(String, nullable=False, index=True)
other_id = Column(String, nullable=True) # Store original user_id other_id = Column(String, nullable=True) # Store original user_id
other_name = Column(String, default="", nullable=False) other_name = Column(String, default="", nullable=False)
@@ -61,4 +62,7 @@ class EndUser(Base):
app = relationship( app = relationship(
"App", "App",
back_populates="end_users" back_populates="end_users"
) )
# 与 WorkSpace 的反向关系
workspace = relationship("Workspace", back_populates="end_users")

View File

@@ -110,7 +110,10 @@ class ToolConfig(Base):
# 元数据 # 元数据
version = Column(String(50), default="1.0.0") version = Column(String(50), default="1.0.0")
tags = Column(JSON, default=list) # 标签列表 tags = Column(JSON, default=list) # 标签列表
# 逻辑删除标志
is_active = Column(Boolean, default=True, server_default='true', nullable=False, index=True, comment="是否可用False表示已删除")
# 时间戳 # 时间戳
created_at = Column(DateTime, default=datetime.now, nullable=False) created_at = Column(DateTime, default=datetime.now, nullable=False)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)

View File

@@ -38,6 +38,7 @@ class Workspace(Base):
members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership
api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys
memory_increments = relationship("MemoryIncrement", back_populates="workspace") memory_increments = relationship("MemoryIncrement", back_populates="workspace")
end_users = relationship("EndUser", back_populates="workspace", cascade="all, delete-orphan")
class WorkspaceMember(Base): class WorkspaceMember(Base):
__tablename__ = "workspace_members" __tablename__ = "workspace_members"

View File

@@ -32,6 +32,21 @@ class EndUserRepository:
db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}") db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
raise raise
def get_end_users_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
"""获取指定 workspace 下的所有 end_user"""
try:
end_users = (
self.db.query(EndUser)
.filter(EndUser.workspace_id == workspace_id)
.all()
)
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
return end_users
except Exception as e:
self.db.rollback()
db_logger.error(f"查询工作空间 {workspace_id} 下终端用户时出错: {str(e)}")
raise
def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据 end_user_id 查询宿主""" """根据 end_user_id 查询宿主"""
try: try:
@@ -52,14 +67,14 @@ class EndUserRepository:
def get_or_create_end_user( def get_or_create_end_user(
self, self,
app_id: uuid.UUID, workspace_id: uuid.UUID,
other_id: str, other_id: str,
original_user_id: Optional[str] = None original_user_id: Optional[str] = None
) -> EndUser: ) -> EndUser:
"""获取或创建终端用户 """获取或创建终端用户
Args: Args:
app_id: 应用ID workspace_id: 工作空间ID
other_id: 第三方ID other_id: 第三方ID
original_user_id: 原始用户ID (存储到 other_id) original_user_id: 原始用户ID (存储到 other_id)
""" """
@@ -68,26 +83,27 @@ class EndUserRepository:
end_user = ( end_user = (
self.db.query(EndUser) self.db.query(EndUser)
.filter( .filter(
EndUser.app_id == app_id, EndUser.workspace_id == workspace_id,
EndUser.other_id == other_id EndUser.other_id == other_id
) )
.order_by(EndUser.created_at.asc())
.first() .first()
) )
if end_user: if end_user:
db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}") db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}")
return end_user return end_user
# 创建新用户 # 创建新用户
end_user = EndUser( end_user = EndUser(
app_id=app_id, workspace_id=workspace_id,
other_id=other_id other_id=other_id
) )
self.db.add(end_user) self.db.add(end_user)
self.db.commit() self.db.commit()
self.db.refresh(end_user) self.db.refresh(end_user)
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}") db_logger.info(f"创建新终端用户: (other_id: {other_id}) for workspace {workspace_id}")
return end_user return end_user
except Exception as e: except Exception as e:
@@ -314,8 +330,7 @@ class EndUserRepository:
try: try:
end_users = ( end_users = (
self.db.query(EndUser) self.db.query(EndUser)
.join(App, EndUser.app_id == App.id) .filter(EndUser.workspace_id == workspace_id)
.filter(App.workspace_id == workspace_id)
.all() .all()
) )
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户") db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
@@ -402,45 +417,79 @@ class EndUserRepository:
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}") db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
raise raise
def batch_update_memory_config_id( # def batch_update_memory_config_id(
self, # self,
app_id: uuid.UUID, # app_id: uuid.UUID,
memory_config_id: uuid.UUID # memory_config_id: uuid.UUID
# ) -> int:
# """批量更新应用下所有终端用户的 memory_config_id
#
# Args:
# app_id: 应用ID
# memory_config_id: 新的记忆配置ID
#
# Returns:
# int: 更新的行数
# """
# try:
# from sqlalchemy import update
#
# stmt = (
# update(EndUser)
# .where(EndUser.app_id == app_id)
# .values(memory_config_id=memory_config_id)
# )
#
# result = self.db.execute(stmt)
# self.db.commit()
#
# updated_count = result.rowcount
#
# db_logger.info(
# f"批量更新终端用户记忆配置: app_id={app_id}, "
# f"memory_config_id={memory_config_id}, updated_count={updated_count}"
# )
#
# return updated_count
#
# except Exception as e:
# self.db.rollback()
# db_logger.error(
# f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
# f"memory_config_id={memory_config_id}, error={str(e)}"
# )
# raise
def batch_update_memory_config_id_by_workspace(
self,
workspace_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int: ) -> int:
"""批量更新应用下所有终端用户的 memory_config_id """批量更新工作空间下所有终端用户的 memory_config_id"""
Args:
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
int: 更新的行数
"""
try: try:
from sqlalchemy import update from sqlalchemy import update
stmt = ( stmt = (
update(EndUser) update(EndUser)
.where(EndUser.app_id == app_id) .where(EndUser.workspace_id == workspace_id)
.values(memory_config_id=memory_config_id) .values(memory_config_id=memory_config_id)
) )
result = self.db.execute(stmt) result = self.db.execute(stmt)
self.db.commit() self.db.commit()
updated_count = result.rowcount updated_count = result.rowcount
db_logger.info( db_logger.info(
f"批量更新终端用户记忆配置: app_id={app_id}, " f"批量更新终端用户记忆配置: workspace_id={workspace_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}" f"memory_config_id={memory_config_id}, updated_count={updated_count}"
) )
return updated_count return updated_count
except Exception as e: except Exception as e:
self.db.rollback() self.db.rollback()
db_logger.error( db_logger.error(
f"批量更新终端用户记忆配置时出错: app_id={app_id}, " f"批量更新终端用户记忆配置时出错: workspace_id={workspace_id}, "
f"memory_config_id={memory_config_id}, error={str(e)}" f"memory_config_id={memory_config_id}, error={str(e)}"
) )
raise raise
@@ -492,7 +541,7 @@ class EndUserRepository:
""" """
try: try:
from sqlalchemy import update from sqlalchemy import update
stmt = ( stmt = (
update(EndUser) update(EndUser)
.where(EndUser.memory_config_id == memory_config_id) .where(EndUser.memory_config_id == memory_config_id)
@@ -519,10 +568,16 @@ class EndUserRepository:
) )
raise raise
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: # def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
"""根据应用ID查询宿主返回 EndUser ORM 列表)""" # """根据应用ID查询宿主返回 EndUser ORM 列表)"""
# repo = EndUserRepository(db)
# end_users = repo.get_end_users_by_app_id(app_id)
# return end_users
def get_end_users_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]:
"""根据工作空间ID查询终端用户返回 EndUser ORM 列表)"""
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_users = repo.get_end_users_by_app_id(app_id) end_users = repo.get_end_users_by_workspace(workspace_id)
return end_users return end_users
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:

View File

@@ -27,7 +27,7 @@ class ToolRepository:
from app.models.app_model import App from app.models.app_model import App
from app.models.workflow_model import WorkflowConfig from app.models.workflow_model import WorkflowConfig
from app.models.workspace_model import Workspace from app.models.workspace_model import Workspace
result = db.query(Workspace.tenant_id).join( result = db.query(Workspace.tenant_id).join(
App, App.workspace_id == Workspace.id App, App.workspace_id == Workspace.id
).join( ).join(
@@ -35,7 +35,7 @@ class ToolRepository:
).filter( ).filter(
WorkflowConfig.id == workflow_id WorkflowConfig.id == workflow_id
).first() ).first()
return result[0] if result else None return result[0] if result else None
@staticmethod @staticmethod
@@ -67,18 +67,19 @@ class ToolRepository:
@staticmethod @staticmethod
def find_by_tenant( def find_by_tenant(
db: Session, db: Session,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
name: Optional[str] = None, name: Optional[str] = None,
tool_type: Optional[ToolType] = None, tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None, status: Optional[ToolStatus] = None,
is_enabled: Optional[bool] = None is_enabled: Optional[bool] = None
) -> List[ToolConfig]: ) -> List[ToolConfig]:
"""根据租户查找工具""" """根据租户查找工具(只返回未删除的)"""
query = db.query(ToolConfig).filter( query = db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
) )
if name: if name:
query = query.filter(ToolConfig.name.ilike(f"%{name}%")) query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
if tool_type: if tool_type:
@@ -91,8 +92,17 @@ class ToolRepository:
return query.all() return query.all()
@staticmethod @staticmethod
def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""根据ID和租户查找工具""" """根据ID和租户查找工具(只返回未删除的)"""
return db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).first()
@staticmethod
def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""根据ID和租户查找工具返回所有工具包括删除的"""
return db.query(ToolConfig).filter( return db.query(ToolConfig).filter(
ToolConfig.id == tool_id, ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id ToolConfig.tenant_id == tenant_id
@@ -100,29 +110,26 @@ class ToolRepository:
@staticmethod @staticmethod
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
"""统计租户工具数量""" """统计租户工具数量(只统计未删除的)"""
return db.query(ToolConfig).filter( return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).count() ).count()
@staticmethod @staticmethod
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取状态统计""" """获取状态统计"""
return db.query( return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter(
ToolConfig.status, ToolConfig.tenant_id == tenant_id,
func.count(ToolConfig.id).label('count') ToolConfig.is_active.is_(True)
).filter(
ToolConfig.tenant_id == tenant_id
).group_by(ToolConfig.status).all() ).group_by(ToolConfig.status).all()
@staticmethod @staticmethod
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取类型统计""" """获取类型统计"""
return db.query( return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter(
ToolConfig.tool_type, ToolConfig.tenant_id == tenant_id,
func.count(ToolConfig.id).label('count') ToolConfig.is_active.is_(True)
).filter(
ToolConfig.tenant_id == tenant_id
).group_by(ToolConfig.tool_type).all() ).group_by(ToolConfig.tool_type).all()
@staticmethod @staticmethod
@@ -130,6 +137,7 @@ class ToolRepository:
"""统计租户启用的工具数量""" """统计租户启用的工具数量"""
return db.query(ToolConfig).filter( return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id, ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True),
ToolConfig.is_enabled == True ToolConfig.is_enabled == True
).count() ).count()
@@ -138,7 +146,8 @@ class ToolRepository:
"""检查租户是否已有内置工具""" """检查租户是否已有内置工具"""
return db.query(ToolConfig).filter( return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id, ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == ToolType.BUILTIN.value ToolConfig.tool_type == ToolType.BUILTIN.value,
ToolConfig.is_active.is_(True)
).count() > 0 ).count() > 0
@@ -194,10 +203,10 @@ class ToolExecutionRepository:
@staticmethod @staticmethod
def find_by_tool_and_tenant( def find_by_tool_and_tenant(
db: Session, db: Session,
tool_id: uuid.UUID, tool_id: uuid.UUID,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
limit: int = 100 limit: int = 100
) -> List[ToolExecution]: ) -> List[ToolExecution]:
"""根据工具和租户查找执行记录""" """根据工具和租户查找执行记录"""
return db.query(ToolExecution).join( return db.query(ToolExecution).join(
@@ -205,4 +214,4 @@ class ToolExecutionRepository:
).filter( ).filter(
ToolConfig.id == tool_id, ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id ToolConfig.tenant_id == tenant_id
).order_by(ToolExecution.started_at.desc()).limit(limit).all() ).order_by(ToolExecution.started_at.desc()).limit(limit).all()

View File

@@ -125,6 +125,85 @@ class SkillConfig(BaseModel):
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能") all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
# ---------- App Features ----------
class FileUploadConfig(BaseModel):
"""文件上传配置"""
enabled: bool = Field(default=False)
# 允许的传输方式local_file / remote_url默认两种都允许
allowed_transfer_methods: List[str] = Field(
default=["local_file", "remote_url"],
description="允许的传输方式"
)
# 图片文件PNG/JPG/JPEG/GIF/WEBP最大 20MB
image_enabled: bool = Field(default=False)
image_max_size_mb: int = Field(default=20)
image_allowed_extensions: List[str] = Field(
default=["png", "jpg", "jpeg", "gif", "webp"]
)
# 语音文件MP3/WAV/M4A/OGG/FLAC最大 50MB
audio_enabled: bool = Field(default=False)
audio_max_size_mb: int = Field(default=50)
audio_allowed_extensions: List[str] = Field(
default=["mp3", "wav", "m4a", "ogg", "flac"]
)
# 通用文件PDF/DOCX/XLSX/TXT/CSV/JSON最大 100MB
document_enabled: bool = Field(default=False)
document_max_size_mb: int = Field(default=100)
document_allowed_extensions: List[str] = Field(
default=["pdf", "docx", "xlsx", "txt", "csv", "json"]
)
# 视频文件MP4/MOV/AVI/WebM最大 500MB
video_enabled: bool = Field(default=False)
video_max_size_mb: int = Field(default=500)
video_allowed_extensions: List[str] = Field(
default=["mp4", "mov", "avi", "webm"]
)
# 最大文件数量
max_file_count: int = Field(default=5, ge=1, le=20)
class OpeningStatementConfig(BaseModel):
"""对话开场白配置"""
enabled: bool = Field(default=False)
statement: Optional[str] = Field(default=None, description="开场白内容")
suggested_questions: List[str] = Field(default_factory=list, description="预设问题列表")
class SuggestedQuestionsConfig(BaseModel):
"""下一步问题建议配置"""
enabled: bool = Field(default=False)
class TextToSpeechConfig(BaseModel):
"""文字转语音配置"""
enabled: bool = Field(default=False)
voice: Optional[str] = Field(default=None, description="语音音色")
language: Optional[str] = Field(default=None, description="语言")
autoplay: bool = Field(default=False, description="是否自动播放")
class CitationConfig(BaseModel):
"""引用和归属配置"""
enabled: bool = Field(default=False)
class WebSearchConfig(BaseModel):
"""联网搜索配置"""
enabled: bool = Field(default=False)
search_engine: Optional[str] = Field(default=None, description="搜索引擎")
class AppFeatures(BaseModel):
"""应用功能特性配置"""
file_upload: FileUploadConfig = Field(default_factory=FileUploadConfig)
opening_statement: OpeningStatementConfig = Field(default_factory=OpeningStatementConfig)
suggested_questions_after_answer: SuggestedQuestionsConfig = Field(default_factory=SuggestedQuestionsConfig)
text_to_speech: TextToSpeechConfig = Field(default_factory=TextToSpeechConfig)
citation: CitationConfig = Field(default_factory=CitationConfig)
web_search: WebSearchConfig = Field(default_factory=WebSearchConfig)
class ToolOldConfig(BaseModel): class ToolOldConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
@@ -201,6 +280,9 @@ class AgentConfigCreate(BaseModel):
# 技能配置 # 技能配置
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
# 功能特性
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
class AppCreate(BaseModel): class AppCreate(BaseModel):
name: str name: str
@@ -258,6 +340,9 @@ class AgentConfigUpdate(BaseModel):
# 技能配置 # 技能配置
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
# 功能特性
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
# ---------- Output Schemas ---------- # ---------- Output Schemas ----------
@@ -323,6 +408,8 @@ class AgentConfig(BaseModel):
skills: Optional[SkillConfig] = {} skills: Optional[SkillConfig] = {}
features: Optional[AppFeatures] = None
is_active: bool is_active: bool
created_at: datetime.datetime created_at: datetime.datetime
updated_at: datetime.datetime updated_at: datetime.datetime
@@ -359,6 +446,14 @@ class AgentConfig(BaseModel):
return {} return {}
return v return v
@field_validator("features", mode="before")
@classmethod
def validate_features(cls, v):
"""处理 None 值,返回默认 AppFeatures"""
if v is None:
return AppFeatures()
return v
@field_serializer("created_at", when_used="json") @field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime): def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
@@ -500,12 +595,35 @@ class DraftRunRequest(BaseModel):
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)") files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
class SuggestedQuestion(BaseModel):
"""建议问题"""
content: str
class CitationSource(BaseModel):
"""引用来源"""
title: str
content: str
score: Optional[float] = None
kb_id: Optional[str] = None
class DraftRunResponse(BaseModel): class DraftRunResponse(BaseModel):
"""试运行响应(非流式)""" """试运行响应(非流式)"""
message: str = Field(..., description="AI 回复消息") message: str = Field(..., description="AI 回复消息")
conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话") conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话")
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况") usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)") elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
class OpeningResponse(BaseModel):
"""应用开场白响应"""
enabled: bool
statement: Optional[str] = None
suggested_questions: List[str] = Field(default_factory=list)
class DraftRunStreamChunk(BaseModel): class DraftRunStreamChunk(BaseModel):

View File

@@ -8,7 +8,7 @@ class EndUser(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
id: uuid.UUID = Field(description="终端用户ID") id: uuid.UUID = Field(description="终端用户ID")
app_id: uuid.UUID = Field(description="应用ID") app_id: Optional[uuid.UUID] = Field(description="应用ID", default=None)
# end_user_id: str = Field(description="终端用户ID") # end_user_id: str = Field(description="终端用户ID")
other_id: Optional[str] = Field(description="第三方ID", default=None) other_id: Optional[str] = Field(description="第三方ID", default=None)
other_name: Optional[str] = Field(description="其他名称", default="") other_name: Optional[str] = Field(description="其他名称", default="")

View File

@@ -90,6 +90,7 @@ class ToolInfo(BaseModel):
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置") config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态") status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
is_active: bool = Field(True, description="是否可用False 表示已删除)")
tags: List[str] = Field(default_factory=list, description="工具标签") tags: List[str] = Field(default_factory=list, description="工具标签")
tenant_id: Optional[str] = Field(None, description="租户ID") tenant_id: Optional[str] = Field(None, description="租户ID")
created_at: datetime = Field(..., description="创建时间") created_at: datetime = Field(..., description="创建时间")
@@ -212,6 +213,11 @@ class ToolUpdateRequest(BaseModel):
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
class ToolActiveUpdate(BaseModel):
"""工具可用状态更新"""
is_active: bool = Field(..., description="True=启用, False=禁用(逻辑删除)")
class ToolExecuteRequest(BaseModel): class ToolExecuteRequest(BaseModel):
"""执行工具请求""" """执行工具请求"""
tool_id: str tool_id: str

View File

@@ -51,6 +51,9 @@ class AgentConfigConverter:
if hasattr(config, "skills") and config.skills: if hasattr(config, "skills") and config.skills:
result["skills"] = config.skills.model_dump() result["skills"] = config.skills.model_dump()
if hasattr(config, "features") and config.features:
result["features"] = config.features.model_dump()
return result return result

View File

@@ -49,12 +49,23 @@ class AppChatService:
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None # 新增:多模态文件 files: Optional[List[FileInput]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
start_time = time.time() start_time = time.time()
config_id = None config_id = None
# 应用 features 配置
features_config: dict = config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# 校验文件上传
self.agent_service._validate_file_upload(features_config, files)
variables = self.agent_service.prepare_variables(variables, config.variables) variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID # 获取模型配置ID
@@ -107,17 +118,14 @@ class AppChatService:
) )
# 加载历史消息 # 加载历史消息
history = [] messages = self.conversation_service.get_messages(
memory_config = {"enabled": True, 'max_history': 10} conversation_id=conversation_id,
if memory_config.get("enabled"): limit=10
messages = self.conversation_service.get_messages( )
conversation_id=conversation_id, history = [
limit=memory_config.get("max_history", 10) {"role": msg.role, "content": msg.content}
) for msg in messages
history = [ ]
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 处理多模态文件 # 处理多模态文件
processed_files = None processed_files = None
@@ -166,6 +174,23 @@ class AppChatService:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# suggested_questions
suggested_questions = []
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
suggested_questions = await self.agent_service._generate_suggested_questions(
features_config, result["content"],
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base}, {}
)
audio_url = await self.agent_service._generate_tts(
features_config, result["content"],
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
tenant_id=tenant_id, workspace_id=workspace_id
)
return { return {
"conversation_id": conversation_id, "conversation_id": conversation_id,
"message_id": str(message_id), "message_id": str(message_id),
@@ -175,7 +200,10 @@ class AppChatService:
"completion_tokens": 0, "completion_tokens": 0,
"total_tokens": 0 "total_tokens": 0
}), }),
"elapsed_time": elapsed_time "elapsed_time": elapsed_time,
"suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
} }
async def agnet_chat_stream( async def agnet_chat_stream(
@@ -190,7 +218,7 @@ class AppChatService:
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None,
files: Optional[List[FileInput]] = None # 新增:多模态文件 files: Optional[List[FileInput]] = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
@@ -198,10 +226,19 @@ class AppChatService:
start_time = time.time() start_time = time.time()
config_id = None config_id = None
message_id = uuid.uuid4() message_id = uuid.uuid4()
yield f"event: start\ndata: {json.dumps({
'conversation_id': str(conversation_id), # 应用 features 配置
"message_id": str(message_id) features_config: dict = config.features or {}
}, ensure_ascii=False)}\n\n" if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# 校验文件上传
self.agent_service._validate_file_upload(features_config, files)
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
variables = self.agent_service.prepare_variables(variables, config.variables) variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID # 获取模型配置ID
@@ -327,8 +364,22 @@ class AppChatService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件 # 发送结束事件(包含 suggested_questions、tts、citations
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions(
features_config, full_content,
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base}, {}
)
end_data["audio_url"] = await self.agent_service._generate_tts(
features_config, full_content,
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
tenant_id=tenant_id, workspace_id=workspace_id
)
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info( logger.info(
@@ -442,7 +493,7 @@ class AppChatService:
try: try:
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 发送开始事件 # 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n" yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
full_content = "" full_content = ""
total_tokens = 0 total_tokens = 0

View File

@@ -109,7 +109,7 @@ class AppService:
return share is not None return share is not None
def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
"""验证应用是否可访问(包括共享应用,用于只读操作) """验证应用是否可访问(包括共享应用,用于只读操作)
Args: Args:
@@ -360,6 +360,7 @@ class AppService:
variables=storage_data.get("variables", []), variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []), tools=storage_data.get("tools", []),
skills=storage_data.get("skills", {}), skills=storage_data.get("skills", {}),
features=storage_data.get("features", {}),
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -1073,6 +1074,7 @@ class AppService:
# if data.tools is not None: # if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", []) agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skills = storage_data.get("skills", {}) agent_cfg.skills = storage_data.get("skills", {})
agent_cfg.features = storage_data.get("features", {})
agent_cfg.updated_at = now agent_cfg.updated_at = now
@@ -1173,6 +1175,7 @@ class AppService:
variables=[], variables=[],
tools=[], tools=[],
skills=[], skills=[],
features={},
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -1389,15 +1392,15 @@ class AppService:
return config.config_id return config.config_id
def _update_endusers_memory_config( def _update_endusers_memory_config_by_workspace(
self, self,
app_id: uuid.UUID, workspace_id: uuid.UUID,
memory_config_id: uuid.UUID memory_config_id: uuid.UUID
) -> int: ) -> int:
"""批量更新应用下所有终端用户的 memory_config_id """批量更新应用下所有终端用户的 memory_config_id
Args: Args:
app_id: 应用ID workspace_id: 工作空间ID
memory_config_id: 新的记忆配置ID memory_config_id: 新的记忆配置ID
Returns: Returns:
@@ -1406,8 +1409,8 @@ class AppService:
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db) repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id( updated_count = repo.batch_update_memory_config_id_by_workspace(
app_id=app_id, workspace_id=workspace_id,
memory_config_id=memory_config_id memory_config_id=memory_config_id
) )
@@ -1578,11 +1581,15 @@ class AppService:
) )
if memory_config_id: if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id) app = self.db.query(App).filter(App.id == app_id).first()
logger.info( if app:
f"发布时更新终端用户记忆配置: app_id={app_id}, " updated_count = self._update_endusers_memory_config_by_workspace(
f"memory_config_id={memory_config_id}, updated_count={updated_count}" app.workspace_id, memory_config_id
) )
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
# 更新当前发布版本指针 # 更新当前发布版本指针
app.current_release_id = release.id app.current_release_id = release.id

View File

@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.config import settings
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
@@ -262,9 +263,12 @@ class AgentRunService:
def load_tools_config(self, tools_config, web_search, tenant_id) -> list: def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置""" """加载工具配置"""
if not tools_config:
return []
tools = [] tools = []
if web_search:
search_tool = create_web_search_tool({})
tools.append(search_tool)
if not tools_config:
return tools
tool_service = ToolService(self.db) tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list): if tools_config and isinstance(tools_config, list):
@@ -273,24 +277,15 @@ class AgentRunService:
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance: if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具 # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict): logger.debug(
web_search_choice = tools_config.get("web_search", {}) "已添加网络搜索工具",
web_search_enable = web_search_choice.get("enabled", False) extra={
if web_search and web_search_enable: "tool_count": len(tools)
search_tool = create_web_search_tool({}) }
tools.append(search_tool) )
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools return tools
def load_skill_config( def load_skill_config(
@@ -373,6 +368,86 @@ class AgentRunService:
) )
return tools, bool(memory_config.get("enabled")) return tools, bool(memory_config.get("enabled"))
@staticmethod
def _validate_file_upload(
features_config: Dict[str, Any],
files: Optional[List[FileInput]]
) -> None:
"""校验上传文件是否符合 file_upload 配置"""
if not files:
return
fu = features_config.get("file_upload", {})
if not (isinstance(fu, dict) and fu.get("enabled")):
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
max_count = fu.get("max_file_count", 5)
if len(files) > max_count:
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
# 校验传输方式
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
for f in files:
if str(f.transfer_method) not in allowed_methods:
raise BusinessException(
f"不支持的文件传输方式:{f.transfer_method},允许的方式:{', '.join(allowed_methods)}",
BizCode.BAD_REQUEST
)
# 各类型对应的开关和大小限制配置键
type_cfg = {
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
}
for f in files:
ftype = str(f.type) # 如 "image", "audio", "document", "video"
cfg = type_cfg.get(ftype)
if cfg is None:
continue
enabled_key, size_key, default_max_mb, label = cfg
# 校验类型开关
if not fu.get(enabled_key):
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
# 校验文件大小(仅当内容已加载时)
content = f.get_content()
if content is not None:
max_mb = fu.get(size_key, default_max_mb)
size_mb = len(content) / (1024 * 1024)
if size_mb > max_mb:
raise BusinessException(
f"{label}文件大小超过限制(最大 {max_mb}MB当前 {size_mb:.1f}MB",
BizCode.BAD_REQUEST
)
@staticmethod
def _inject_opening_statement(
features_config: Dict[str, Any],
system_prompt: str,
is_new_conversation: bool
) -> str:
"""首轮对话时将开场白注入 system_prompt"""
if not is_new_conversation:
return system_prompt
opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return system_prompt
statement = opening["statement"]
return f"{system_prompt}\n\n[对话开场白]\n{statement}"
@staticmethod
def _filter_citations(
features_config: Dict[str, Any],
citations: List[Any]
) -> List[Any]:
"""根据 citation 开关决定是否返回引用来源"""
citation_cfg = features_config.get("citation", {})
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
return citations
return []
async def run( async def run(
self, self,
*, *,
@@ -415,6 +490,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关(优先级高于参数默认值)
web_search_feature = features_config.get("web_search", {})
if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
try: try:
# 1. 获取 API Key 配置 # 1. 获取 API Key 配置
@@ -449,6 +533,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换) # 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
@@ -491,12 +579,10 @@ class AgentRunService:
) )
# 6. 加载历史消息 # 6. 加载历史消息
history = [] history = await self._load_conversation_history(
if memory_config and memory_config.get("enabled"): conversation_id=conversation_id,
history = await self._load_conversation_history( max_history=10
conversation_id=conversation_id, )
max_history=agent_config.memory.get("max_history", 10)
)
# 6. 处理多模态文件 # 6. 处理多模态文件
processed_files = None processed_files = None
@@ -551,7 +637,7 @@ class AgentRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息 # 9. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"): if not sub_agent:
await self._save_conversation_message( await self._save_conversation_message(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=message, user_message=message,
@@ -575,7 +661,15 @@ class AgentRunService:
"completion_tokens": 0, "completion_tokens": 0,
"total_tokens": 0 "total_tokens": 0
}), }),
"elapsed_time": elapsed_time "elapsed_time": elapsed_time,
"suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])),
"audio_url": await self._generate_tts(
features_config, result["content"], api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
) if not sub_agent else None,
} }
logger.info( logger.info(
@@ -630,6 +724,15 @@ class AgentRunService:
skills_config: dict | None = agent_config.skills skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory memory_config: dict | None = agent_config.memory
features_config: dict = agent_config.features or {}
# 从 features 中读取功能开关
web_search_feature = features_config.get("web_search", {})
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
web_search = False
# file_upload 校验
self._validate_file_upload(features_config, files)
start_time = time.time() start_time = time.time()
@@ -659,6 +762,10 @@ class AgentRunService:
# 3. 处理系统提示词(支持变量替换) # 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
# opening_statement首轮对话注入开场白
is_new_conversation = not conversation_id
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
# 4. 准备工具列表 # 4. 准备工具列表
tools = [] tools = []
@@ -703,12 +810,10 @@ class AgentRunService:
) )
# 6. 加载历史消息 # 6. 加载历史消息
history = [] history = await self._load_conversation_history(
if memory_config and memory_config.get("enabled"): conversation_id=conversation_id,
history = await self._load_conversation_history( max_history=memory_config.get("max_history", 10)
conversation_id=conversation_id, )
max_history=memory_config.get("max_history", 10)
)
# 6. 处理多模态文件 # 6. 处理多模态文件
processed_files = None processed_files = None
@@ -774,7 +879,7 @@ class AgentRunService:
}) })
# 10. 保存会话消息 # 10. 保存会话消息
if not sub_agent and memory_config and memory_config.get("enabled"): if not sub_agent:
await self._save_conversation_message( await self._save_conversation_message(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=message, user_message=message,
@@ -786,12 +891,22 @@ class AgentRunService:
} }
) )
# 11. 发送结束事件 # 11. 发送结束事件(包含 suggested_questions 和 tts
yield self._format_sse_event("end", { end_data: Dict[str, Any] = {
"conversation_id": conversation_id, "conversation_id": conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"message_length": len(full_content) "message_length": len(full_content)
}) }
if not sub_agent:
end_data["suggested_questions"] = await self._generate_suggested_questions(
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = await self._generate_tts(
features_config, full_content, api_key_config,
tenant_id=tenant_id, workspace_id=workspace_id
)
end_data["citations"] = self._filter_citations(features_config, [])
yield self._format_sse_event("end", end_data)
logger.info( logger.info(
"流式试运行完成", "流式试运行完成",
@@ -1137,6 +1252,165 @@ class AgentRunService:
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)}) logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {} return {}
async def _generate_suggested_questions(
self,
features_config: Dict[str, Any],
assistant_message: str,
api_key_config: Dict[str, Any],
effective_params: Dict[str, Any]
) -> List[str]:
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
sq_config = features_config.get("suggested_questions_after_answer", {})
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
return []
try:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
llm = ChatOpenAI(
model=api_key_config["model_name"],
api_key=api_key_config["api_key"],
base_url=api_key_config.get("api_base"),
temperature=0.5,
max_tokens=200,
)
prompt = (
f"根据以下AI回复生成3个用户可能继续追问的简短问题每行一个不加序号\n\n{assistant_message}"
)
resp = await llm.ainvoke([HumanMessage(content=prompt)])
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
return lines[:3]
except Exception as e:
logger.warning(f"生成建议问题失败: {e}")
return []
async def _generate_tts(
self,
features_config: Dict[str, Any],
text: str,
api_key_config: Dict[str, Any],
tenant_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
) -> Optional[str]:
"""根据 text_to_speech 配置生成语音,上传到存储并返回 URL"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
return None
if not text or not text.strip():
return None
try:
from app.services.file_storage_service import FileStorageService
provider = api_key_config.get("provider", "openai")
api_key = api_key_config.get("api_key")
api_base = api_key_config.get("api_base")
voice = tts_config.get("voice")
if provider == "dashscope":
audio_bytes, file_ext, content_type = await self._tts_dashscope(
api_key=api_key,
text=text,
voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀
tts_config=tts_config,
)
else:
# OpenAI 兼容接口openai / xinference / gpustack 等)
audio_bytes, file_ext, content_type = await self._tts_openai(
api_key=api_key,
api_base=api_base,
text=text,
voice=voice or "alloy",
)
storage_service = FileStorageService()
file_id = uuid.uuid4()
file_key = await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=audio_bytes,
content_type=content_type,
)
# 保存文件元数据到数据库
from app.models.file_metadata_model import FileMetadata
db_file = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=f"tts_{file_id}{file_ext}",
file_ext=file_ext,
file_size=len(audio_bytes),
content_type=content_type,
status="completed",
)
self.db.add(db_file)
self.db.commit()
server_url = settings.FILE_LOCAL_SERVER_URL
audio_url = f"{server_url}/storage/permanent/{file_id}"
logger.debug(f"TTS 生成成功provider={provider}, file_key={file_key}")
return audio_url
except Exception as e:
logger.warning(f"TTS 生成失败: {e}")
return None
@staticmethod
async def _tts_openai(
api_key: str,
api_base: Optional[str],
text: str,
voice: str,
) -> tuple:
"""OpenAI 兼容 TTS返回 (audio_bytes, file_ext, content_type)"""
from openai import AsyncOpenAI
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
response = await client.audio.speech.create(
model="tts-1",
voice=voice,
input=text[:4096],
)
return response.content, ".mp3", "audio/mpeg"
@staticmethod
async def _tts_dashscope(
api_key: str,
text: str,
voice: str,
tts_config: Dict[str, Any],
) -> tuple:
"""DashScope CosyVoice TTS返回 (audio_bytes, file_ext, content_type)"""
import dashscope
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
model = tts_config.get("model") or "cosyvoice-v2"
is_v2 = model.endswith("-v2")
# cosyvoice-v2 音色名带 _v2 后缀v1 不带
# 如果用户传入的 voice 不匹配当前模型版本,自动修正
if is_v2 and not voice.endswith("_v2"):
voice = voice + "_v2"
elif not is_v2 and voice.endswith("_v2"):
voice = voice[:-3] # 去掉 _v2
def _sync_call() -> bytes:
dashscope.api_key = api_key
synthesizer = SpeechSynthesizer(
model=model,
voice=voice,
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
)
audio = synthesizer.call(text[:4096])
if audio is None:
raise RuntimeError("DashScope TTS 返回空音频")
return audio
audio_bytes = await asyncio.to_thread(_sync_call)
return audio_bytes, ".mp3", "audio/mpeg"
def _replace_variables( def _replace_variables(
self, self,
text: str, text: str,
@@ -1221,6 +1495,12 @@ class AgentRunService:
} }
) )
# 提前校验文件上传(与 run() 内部保持一致)
features_config: dict = agent_config.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
async def run_single_model(model_info): async def run_single_model(model_info):
"""运行单个模型""" """运行单个模型"""
try: try:
@@ -1271,6 +1551,9 @@ class AgentRunService:
if elapsed > 0 and usage.get("completion_tokens") else None if elapsed > 0 and usage.get("completion_tokens") else None
), ),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]), "cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None "error": None
} }
@@ -1343,7 +1626,12 @@ class AgentRunService:
) )
return { return {
"results": results, "results": [{
**r,
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results), "total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
"successful_count": len(successful), "successful_count": len(successful),
"failed_count": len(failed), "failed_count": len(failed),
@@ -1434,6 +1722,12 @@ class AgentRunService:
extra={"model_count": len(models), "parallel": parallel} extra={"model_count": len(models), "parallel": parallel}
) )
# 提前校验文件上传
# features_config: dict = agent_config.features or {}
# if hasattr(features_config, 'model_dump'):
# features_config = features_config.model_dump()
# self._validate_file_upload(features_config, files)
# 发送开始事件 # 发送开始事件
yield self._format_sse_event("compare_start", { yield self._format_sse_event("compare_start", {
"conversation_id": conversation_id, "conversation_id": conversation_id,
@@ -1465,6 +1759,9 @@ class AgentRunService:
start_time = time.time() start_time = time.time()
full_content = "" full_content = ""
returned_conversation_id = model_conversation_id returned_conversation_id = model_conversation_id
audio_url = None
citations = []
suggested_questions = []
# 临时修改参数 # 临时修改参数
original_params = agent_config.model_parameters original_params = agent_config.model_parameters
@@ -1518,6 +1815,12 @@ class AgentRunService:
"content": chunk "content": chunk
})) }))
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])
if event_type == "error" and event_data: if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", { await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx, "model_index": idx,
@@ -1543,6 +1846,9 @@ class AgentRunService:
"parameters_used": model_info["parameters"], "parameters_used": model_info["parameters"],
"message": full_content, "message": full_content,
"elapsed_time": elapsed, "elapsed_time": elapsed,
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None "error": None
} }
@@ -1554,6 +1860,9 @@ class AgentRunService:
"conversation_id": returned_conversation_id, "conversation_id": returned_conversation_id,
"elapsed_time": elapsed, "elapsed_time": elapsed,
"message_length": len(full_content), "message_length": len(full_content),
"audio_url": audio_url,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time() "timestamp": time.time()
})) }))
@@ -1685,8 +1994,11 @@ class AgentRunService:
"model_name": r["model_name"], "model_name": r["model_name"],
"label": r["label"], "label": r["label"],
"conversation_id": r.get("conversation_id"), "conversation_id": r.get("conversation_id"),
"message": r.get("message"), # 包含完整消息 "message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0), "elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error") "error": r.get("error")
}) })

View File

@@ -68,14 +68,14 @@ def get_workspace_end_users(
return [] return []
# 提取所有 app_id # 提取所有 app_id
app_ids = [app.id for app in apps_orm] # app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询 # 批量查询所有 end_users一次查询而非循环查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性 # 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter( end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids) EndUserModel.workspace_id == workspace_id
).order_by( ).order_by(
nullslast(desc(EndUserModel.created_at)), nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id) desc(EndUserModel.id)

View File

@@ -78,7 +78,7 @@ class ToolService:
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]: def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
"""获取工具详情""" """获取工具详情"""
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
return self._config_to_info(config) if config else None return self._config_to_info(config) if config else None
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None): def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
@@ -237,7 +237,7 @@ class ToolService:
return False return False
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool: def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
"""删除工具""" """删除工具(逻辑删除)"""
config = self._get_tool_config(tool_id, tenant_id) config = self._get_tool_config(tool_id, tenant_id)
if not config: if not config:
return False return False
@@ -246,14 +246,7 @@ class ToolService:
raise ValueError("内置工具不允许删除") raise ValueError("内置工具不允许删除")
try: try:
# 删除关联表记录 config.is_active = False
if config.tool_type == ToolType.CUSTOM.value:
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
elif config.tool_type == ToolType.MCP.value:
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
# 删除主表记录ToolExecution会通过cascade自动删除
self.db.delete(config)
self._clear_tool_cache(tool_id) self._clear_tool_cache(tool_id)
self.db.commit() self.db.commit()
return True return True
@@ -262,6 +255,27 @@ class ToolService:
logger.error(f"删除工具失败: {tool_id}, {e}") logger.error(f"删除工具失败: {tool_id}, {e}")
return False return False
def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool:
"""设置工具可用状态(启用/禁用)"""
# 直接查询,包含 is_active=False 的记录
config = self.db.query(ToolConfig).filter(
ToolConfig.id == uuid.UUID(tool_id),
ToolConfig.tenant_id == tenant_id
).first()
if not config:
return False
if config.tool_type == ToolType.BUILTIN.value:
raise ValueError("内置工具不允许修改可用状态")
try:
config.is_active = is_active
self._clear_tool_cache(tool_id)
self.db.commit()
return True
except Exception as e:
self.db.rollback()
logger.error(f"设置工具状态失败: {tool_id}, {e}")
return False
async def execute_tool( async def execute_tool(
self, self,
tool_id: str, tool_id: str,
@@ -378,7 +392,7 @@ class ToolService:
Returns: Returns:
方法列表或None 方法列表或None
""" """
config = self._get_tool_config(tool_id, tenant_id) config = self._get_tool_config_all(tool_id, tenant_id)
if not config: if not config:
return None return None
@@ -857,16 +871,20 @@ class ToolService:
} }
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]: def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""获取工具配置""" """获取工具配置(仅返回 is_active=True)"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""获取工具配置(返回所有)"""
return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例""" """获取工具实例(仅返回 is_active=True 的工具)"""
if tool_id in self._tool_cache: if tool_id in self._tool_cache:
return self._tool_cache[tool_id] return self._tool_cache[tool_id]
config = self._get_tool_config(tool_id, tenant_id) config = self._get_tool_config(tool_id, tenant_id)
if not config: if not config or not config.is_active:
return None return None
try: try:
@@ -980,6 +998,7 @@ class ToolService:
tags=config.tags or [], tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None, tenant_id=str(config.tenant_id) if config.tenant_id else None,
config_data=config_data, config_data=config_data,
is_active=config.is_active,
created_at=config.created_at created_at=config.created_at
) )

View File

@@ -1292,9 +1292,9 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
} }
# 2. 查询所有app下的end_user_id去重 # 2. 查询所有app下的end_user_id去重
app_ids = [app.id for app in apps] # app_ids = [app.id for app in apps]
end_users = db.query(EndUser.id).filter( end_users = db.query(EndUser.id).filter(
EndUser.app_id.in_(app_ids) EndUser.workspace_id == workspace_id
).distinct().all() ).distinct().all()
# 3. 遍历所有end_user查询每个宿主的记忆总量并累加 # 3. 遍历所有end_user查询每个宿主的记忆总量并累加
@@ -1433,9 +1433,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
continue continue
# 2. 查询所有app下的end_user_id去重 # 2. 查询所有app下的end_user_id去重
app_ids = [app.id for app in apps] # app_ids = [app.id for app in apps]
end_users = db.query(EndUser.id).filter( end_users = db.query(EndUser.id).filter(
EndUser.app_id.in_(app_ids) EndUser.workspace_id == workspace_id
).distinct().all() ).distinct().all()
# 3. 遍历所有end_user查询每个宿主的记忆总量并累加 # 3. 遍历所有end_user查询每个宿主的记忆总量并累加