Merge pull request #578 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
Feature/app
This commit is contained in:
@@ -254,6 +254,27 @@ def get_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_opening(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
statement=opening.get("statement"),
|
||||
suggested_questions=opening.get("suggested_questions", []),
|
||||
))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
@@ -513,11 +534,11 @@ async def draft_run(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -845,11 +866,11 @@ async def draft_run_compare(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -898,7 +919,12 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
|
||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||
features_config: dict = agent_cfg.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
@@ -915,7 +941,7 @@ async def draft_run_compare(
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -946,7 +972,7 @@ async def draft_run_compare(
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
@@ -215,8 +216,10 @@ def list_conversations(
|
||||
service = SharedChatService(db)
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
@@ -308,25 +311,28 @@ async def chat(
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
original_user_id=user_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid = share.app_id
|
||||
# appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
# app = db.query(App).filter(
|
||||
# App.id == appid,
|
||||
# App.is_active.is_(True)
|
||||
# ).first()
|
||||
# if not app:
|
||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
# workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
@@ -654,17 +660,20 @@ async def config_query(
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
"variables": workflow_service.get_start_node_variables(release.config),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
"variables": release.config.get("variables"),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
"variables": [],
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -94,9 +94,8 @@ async def chat(
|
||||
workspace_id = app.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search = True
|
||||
|
||||
@@ -4,7 +4,8 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||
CustomToolTestRequest, ToolActiveUpdate
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
@@ -156,7 +157,7 @@ async def delete_tool(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除,is_active=False)"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
@@ -168,6 +169,30 @@ async def delete_tool(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||
async def set_tool_active(
|
||||
tool_id: str,
|
||||
request: ToolActiveUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""设置工具可用状态(启用/禁用)
|
||||
|
||||
- is_active=true: 启用工具
|
||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||
"""
|
||||
try:
|
||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
action = "启用" if request.is_active else "禁用"
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution/execute", response_model=ApiResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecuteRequest,
|
||||
|
||||
@@ -23,7 +23,7 @@ class SimpleMCPClient:
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
self.timeout = self.connection_config.get("timeout", 30)
|
||||
self.timeout = self.connection_config.get("timeout", 10)
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
|
||||
@@ -16,7 +16,7 @@ engine = create_engine(
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
connect_args={
|
||||
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000"
|
||||
"options": "-c timezone=UTC -c statement_timeout=60000"
|
||||
},
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
@@ -31,6 +31,7 @@ class AgentConfig(Base):
|
||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
|
||||
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
|
||||
features = Column(JSON, default=dict, nullable=True, comment="功能特性配置")
|
||||
|
||||
# 多 Agent 相关字段
|
||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||
|
||||
@@ -12,7 +12,8 @@ class EndUser(Base):
|
||||
__tablename__ = "end_users"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
|
||||
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False)
|
||||
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=True)
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False)
|
||||
# end_user_id = Column(String, nullable=False, index=True)
|
||||
other_id = Column(String, nullable=True) # Store original user_id
|
||||
other_name = Column(String, default="", nullable=False)
|
||||
@@ -61,4 +62,7 @@ class EndUser(Base):
|
||||
app = relationship(
|
||||
"App",
|
||||
back_populates="end_users"
|
||||
)
|
||||
)
|
||||
|
||||
# 与 WorkSpace 的反向关系
|
||||
workspace = relationship("Workspace", back_populates="end_users")
|
||||
@@ -110,7 +110,10 @@ class ToolConfig(Base):
|
||||
# 元数据
|
||||
version = Column(String(50), default="1.0.0")
|
||||
tags = Column(JSON, default=list) # 标签列表
|
||||
|
||||
|
||||
# 逻辑删除标志
|
||||
is_active = Column(Boolean, default=True, server_default='true', nullable=False, index=True, comment="是否可用,False表示已删除")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
|
||||
|
||||
@@ -38,6 +38,7 @@ class Workspace(Base):
|
||||
members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership
|
||||
api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys
|
||||
memory_increments = relationship("MemoryIncrement", back_populates="workspace")
|
||||
end_users = relationship("EndUser", back_populates="workspace", cascade="all, delete-orphan")
|
||||
|
||||
class WorkspaceMember(Base):
|
||||
__tablename__ = "workspace_members"
|
||||
|
||||
@@ -32,6 +32,21 @@ class EndUserRepository:
|
||||
db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_end_users_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取指定 workspace 下的所有 end_user"""
|
||||
try:
|
||||
end_users = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.workspace_id == workspace_id)
|
||||
.all()
|
||||
)
|
||||
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||
return end_users
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询工作空间 {workspace_id} 下终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据 end_user_id 查询宿主"""
|
||||
try:
|
||||
@@ -52,14 +67,14 @@ class EndUserRepository:
|
||||
|
||||
def get_or_create_end_user(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
other_id: str,
|
||||
original_user_id: Optional[str] = None
|
||||
) -> EndUser:
|
||||
"""获取或创建终端用户
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
workspace_id: 工作空间ID
|
||||
other_id: 第三方ID
|
||||
original_user_id: 原始用户ID (存储到 other_id)
|
||||
"""
|
||||
@@ -68,26 +83,27 @@ class EndUserRepository:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(
|
||||
EndUser.app_id == app_id,
|
||||
EndUser.workspace_id == workspace_id,
|
||||
EndUser.other_id == other_id
|
||||
)
|
||||
.order_by(EndUser.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user:
|
||||
db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}")
|
||||
db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}")
|
||||
return end_user
|
||||
|
||||
# 创建新用户
|
||||
end_user = EndUser(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
self.db.add(end_user)
|
||||
self.db.commit()
|
||||
self.db.refresh(end_user)
|
||||
|
||||
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}")
|
||||
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for workspace {workspace_id}")
|
||||
return end_user
|
||||
|
||||
except Exception as e:
|
||||
@@ -314,8 +330,7 @@ class EndUserRepository:
|
||||
try:
|
||||
end_users = (
|
||||
self.db.query(EndUser)
|
||||
.join(App, EndUser.app_id == App.id)
|
||||
.filter(App.workspace_id == workspace_id)
|
||||
.filter(EndUser.workspace_id == workspace_id)
|
||||
.all()
|
||||
)
|
||||
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||
@@ -402,45 +417,79 @@ class EndUserRepository:
|
||||
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def batch_update_memory_config_id(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
# def batch_update_memory_config_id(
|
||||
# self,
|
||||
# app_id: uuid.UUID,
|
||||
# memory_config_id: uuid.UUID
|
||||
# ) -> int:
|
||||
# """批量更新应用下所有终端用户的 memory_config_id
|
||||
#
|
||||
# Args:
|
||||
# app_id: 应用ID
|
||||
# memory_config_id: 新的记忆配置ID
|
||||
#
|
||||
# Returns:
|
||||
# int: 更新的行数
|
||||
# """
|
||||
# try:
|
||||
# from sqlalchemy import update
|
||||
#
|
||||
# stmt = (
|
||||
# update(EndUser)
|
||||
# .where(EndUser.app_id == app_id)
|
||||
# .values(memory_config_id=memory_config_id)
|
||||
# )
|
||||
#
|
||||
# result = self.db.execute(stmt)
|
||||
# self.db.commit()
|
||||
#
|
||||
# updated_count = result.rowcount
|
||||
#
|
||||
# db_logger.info(
|
||||
# f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||
# f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
# )
|
||||
#
|
||||
# return updated_count
|
||||
#
|
||||
# except Exception as e:
|
||||
# self.db.rollback()
|
||||
# db_logger.error(
|
||||
# f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||
# f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||
# )
|
||||
# raise
|
||||
|
||||
def batch_update_memory_config_id_by_workspace(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 更新的行数
|
||||
"""
|
||||
"""批量更新工作空间下所有终端用户的 memory_config_id"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.app_id == app_id)
|
||||
.where(EndUser.workspace_id == workspace_id)
|
||||
.values(memory_config_id=memory_config_id)
|
||||
)
|
||||
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
|
||||
db_logger.info(
|
||||
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"批量更新终端用户记忆配置: workspace_id={workspace_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(
|
||||
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||
f"批量更新终端用户记忆配置时出错: workspace_id={workspace_id}, "
|
||||
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||
)
|
||||
raise
|
||||
@@ -492,7 +541,7 @@ class EndUserRepository:
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.memory_config_id == memory_config_id)
|
||||
@@ -519,10 +568,16 @@ class EndUserRepository:
|
||||
)
|
||||
raise
|
||||
|
||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||
"""根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||
# def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||
# """根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||
# repo = EndUserRepository(db)
|
||||
# end_users = repo.get_end_users_by_app_id(app_id)
|
||||
# return end_users
|
||||
|
||||
def get_end_users_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""根据工作空间ID查询终端用户(返回 EndUser ORM 列表)"""
|
||||
repo = EndUserRepository(db)
|
||||
end_users = repo.get_end_users_by_app_id(app_id)
|
||||
end_users = repo.get_end_users_by_workspace(workspace_id)
|
||||
return end_users
|
||||
|
||||
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
|
||||
@@ -27,7 +27,7 @@ class ToolRepository:
|
||||
from app.models.app_model import App
|
||||
from app.models.workflow_model import WorkflowConfig
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
|
||||
result = db.query(Workspace.tenant_id).join(
|
||||
App, App.workspace_id == Workspace.id
|
||||
).join(
|
||||
@@ -35,7 +35,7 @@ class ToolRepository:
|
||||
).filter(
|
||||
WorkflowConfig.id == workflow_id
|
||||
).first()
|
||||
|
||||
|
||||
return result[0] if result else None
|
||||
|
||||
@staticmethod
|
||||
@@ -67,18 +67,19 @@ class ToolRepository:
|
||||
|
||||
@staticmethod
|
||||
def find_by_tenant(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
is_enabled: Optional[bool] = None
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
is_enabled: Optional[bool] = None
|
||||
) -> List[ToolConfig]:
|
||||
"""根据租户查找工具"""
|
||||
"""根据租户查找工具(只返回未删除的)"""
|
||||
query = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True)
|
||||
)
|
||||
|
||||
|
||||
if name:
|
||||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||||
if tool_type:
|
||||
@@ -91,8 +92,17 @@ class ToolRepository:
|
||||
return query.all()
|
||||
|
||||
@staticmethod
|
||||
def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""根据ID和租户查找工具"""
|
||||
def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""根据ID和租户查找工具(只返回未删除的)"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
@staticmethod
|
||||
def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""根据ID和租户查找工具(返回所有工具包括删除的)"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
@@ -100,29 +110,26 @@ class ToolRepository:
|
||||
|
||||
@staticmethod
|
||||
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||||
"""统计租户工具数量"""
|
||||
"""统计租户工具数量(只统计未删除的)"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True)
|
||||
).count()
|
||||
|
||||
@staticmethod
|
||||
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||
"""获取状态统计"""
|
||||
return db.query(
|
||||
ToolConfig.status,
|
||||
func.count(ToolConfig.id).label('count')
|
||||
).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True)
|
||||
).group_by(ToolConfig.status).all()
|
||||
|
||||
@staticmethod
|
||||
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||
"""获取类型统计"""
|
||||
return db.query(
|
||||
ToolConfig.tool_type,
|
||||
func.count(ToolConfig.id).label('count')
|
||||
).filter(
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True)
|
||||
).group_by(ToolConfig.tool_type).all()
|
||||
|
||||
@staticmethod
|
||||
@@ -130,6 +137,7 @@ class ToolRepository:
|
||||
"""统计租户启用的工具数量"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.is_active.is_(True),
|
||||
ToolConfig.is_enabled == True
|
||||
).count()
|
||||
|
||||
@@ -138,7 +146,8 @@ class ToolRepository:
|
||||
"""检查租户是否已有内置工具"""
|
||||
return db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN.value
|
||||
ToolConfig.tool_type == ToolType.BUILTIN.value,
|
||||
ToolConfig.is_active.is_(True)
|
||||
).count() > 0
|
||||
|
||||
|
||||
@@ -194,10 +203,10 @@ class ToolExecutionRepository:
|
||||
|
||||
@staticmethod
|
||||
def find_by_tool_and_tenant(
|
||||
db: Session,
|
||||
tool_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 100
|
||||
db: Session,
|
||||
tool_id: uuid.UUID,
|
||||
tenant_id: uuid.UUID,
|
||||
limit: int = 100
|
||||
) -> List[ToolExecution]:
|
||||
"""根据工具和租户查找执行记录"""
|
||||
return db.query(ToolExecution).join(
|
||||
@@ -205,4 +214,4 @@ class ToolExecutionRepository:
|
||||
).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).order_by(ToolExecution.started_at.desc()).limit(limit).all()
|
||||
).order_by(ToolExecution.started_at.desc()).limit(limit).all()
|
||||
|
||||
@@ -125,6 +125,85 @@ class SkillConfig(BaseModel):
|
||||
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
|
||||
|
||||
|
||||
# ---------- App Features ----------
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
"""文件上传配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
# 允许的传输方式:local_file / remote_url,默认两种都允许
|
||||
allowed_transfer_methods: List[str] = Field(
|
||||
default=["local_file", "remote_url"],
|
||||
description="允许的传输方式"
|
||||
)
|
||||
# 图片文件:PNG/JPG/JPEG/GIF/WEBP,最大 20MB
|
||||
image_enabled: bool = Field(default=False)
|
||||
image_max_size_mb: int = Field(default=20)
|
||||
image_allowed_extensions: List[str] = Field(
|
||||
default=["png", "jpg", "jpeg", "gif", "webp"]
|
||||
)
|
||||
# 语音文件:MP3/WAV/M4A/OGG/FLAC,最大 50MB
|
||||
audio_enabled: bool = Field(default=False)
|
||||
audio_max_size_mb: int = Field(default=50)
|
||||
audio_allowed_extensions: List[str] = Field(
|
||||
default=["mp3", "wav", "m4a", "ogg", "flac"]
|
||||
)
|
||||
# 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB
|
||||
document_enabled: bool = Field(default=False)
|
||||
document_max_size_mb: int = Field(default=100)
|
||||
document_allowed_extensions: List[str] = Field(
|
||||
default=["pdf", "docx", "xlsx", "txt", "csv", "json"]
|
||||
)
|
||||
# 视频文件:MP4/MOV/AVI/WebM,最大 500MB
|
||||
video_enabled: bool = Field(default=False)
|
||||
video_max_size_mb: int = Field(default=500)
|
||||
video_allowed_extensions: List[str] = Field(
|
||||
default=["mp4", "mov", "avi", "webm"]
|
||||
)
|
||||
# 最大文件数量
|
||||
max_file_count: int = Field(default=5, ge=1, le=20)
|
||||
|
||||
|
||||
class OpeningStatementConfig(BaseModel):
|
||||
"""对话开场白配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
statement: Optional[str] = Field(default=None, description="开场白内容")
|
||||
suggested_questions: List[str] = Field(default_factory=list, description="预设问题列表")
|
||||
|
||||
|
||||
class SuggestedQuestionsConfig(BaseModel):
|
||||
"""下一步问题建议配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
|
||||
class TextToSpeechConfig(BaseModel):
|
||||
"""文字转语音配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
voice: Optional[str] = Field(default=None, description="语音音色")
|
||||
language: Optional[str] = Field(default=None, description="语言")
|
||||
autoplay: bool = Field(default=False, description="是否自动播放")
|
||||
|
||||
|
||||
class CitationConfig(BaseModel):
|
||||
"""引用和归属配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
|
||||
class WebSearchConfig(BaseModel):
|
||||
"""联网搜索配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
search_engine: Optional[str] = Field(default=None, description="搜索引擎")
|
||||
|
||||
|
||||
class AppFeatures(BaseModel):
|
||||
"""应用功能特性配置"""
|
||||
file_upload: FileUploadConfig = Field(default_factory=FileUploadConfig)
|
||||
opening_statement: OpeningStatementConfig = Field(default_factory=OpeningStatementConfig)
|
||||
suggested_questions_after_answer: SuggestedQuestionsConfig = Field(default_factory=SuggestedQuestionsConfig)
|
||||
text_to_speech: TextToSpeechConfig = Field(default_factory=TextToSpeechConfig)
|
||||
citation: CitationConfig = Field(default_factory=CitationConfig)
|
||||
web_search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
||||
|
||||
|
||||
class ToolOldConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
@@ -201,6 +280,9 @@ class AgentConfigCreate(BaseModel):
|
||||
# 技能配置
|
||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||
|
||||
# 功能特性
|
||||
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
|
||||
|
||||
|
||||
class AppCreate(BaseModel):
|
||||
name: str
|
||||
@@ -258,6 +340,9 @@ class AgentConfigUpdate(BaseModel):
|
||||
# 技能配置
|
||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||
|
||||
# 功能特性
|
||||
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
@@ -323,6 +408,8 @@ class AgentConfig(BaseModel):
|
||||
|
||||
skills: Optional[SkillConfig] = {}
|
||||
|
||||
features: Optional[AppFeatures] = None
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
@@ -359,6 +446,14 @@ class AgentConfig(BaseModel):
|
||||
return {}
|
||||
return v
|
||||
|
||||
@field_validator("features", mode="before")
|
||||
@classmethod
|
||||
def validate_features(cls, v):
|
||||
"""处理 None 值,返回默认 AppFeatures"""
|
||||
if v is None:
|
||||
return AppFeatures()
|
||||
return v
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -500,12 +595,35 @@ class DraftRunRequest(BaseModel):
|
||||
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class SuggestedQuestion(BaseModel):
|
||||
"""建议问题"""
|
||||
content: str
|
||||
|
||||
|
||||
class CitationSource(BaseModel):
|
||||
"""引用来源"""
|
||||
title: str
|
||||
content: str
|
||||
score: Optional[float] = None
|
||||
kb_id: Optional[str] = None
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
"""试运行响应(非流式)"""
|
||||
message: str = Field(..., description="AI 回复消息")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
||||
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
||||
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
||||
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
|
||||
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||
|
||||
|
||||
class OpeningResponse(BaseModel):
|
||||
"""应用开场白响应"""
|
||||
enabled: bool
|
||||
statement: Optional[str] = None
|
||||
suggested_questions: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DraftRunStreamChunk(BaseModel):
|
||||
|
||||
@@ -8,7 +8,7 @@ class EndUser(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID = Field(description="终端用户ID")
|
||||
app_id: uuid.UUID = Field(description="应用ID")
|
||||
app_id: Optional[uuid.UUID] = Field(description="应用ID", default=None)
|
||||
# end_user_id: str = Field(description="终端用户ID")
|
||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||
|
||||
@@ -90,6 +90,7 @@ class ToolInfo(BaseModel):
|
||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
|
||||
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
|
||||
is_active: bool = Field(True, description="是否可用(False 表示已删除)")
|
||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
@@ -212,6 +213,11 @@ class ToolUpdateRequest(BaseModel):
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ToolActiveUpdate(BaseModel):
|
||||
"""工具可用状态更新"""
|
||||
is_active: bool = Field(..., description="True=启用, False=禁用(逻辑删除)")
|
||||
|
||||
|
||||
class ToolExecuteRequest(BaseModel):
|
||||
"""执行工具请求"""
|
||||
tool_id: str
|
||||
|
||||
@@ -51,6 +51,9 @@ class AgentConfigConverter:
|
||||
|
||||
if hasattr(config, "skills") and config.skills:
|
||||
result["skills"] = config.skills.model_dump()
|
||||
|
||||
if hasattr(config, "features") and config.features:
|
||||
result["features"] = config.features.model_dump()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -49,12 +49,23 @@ class AppChatService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
files: Optional[List[FileInput]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||
web_search = False
|
||||
|
||||
# 校验文件上传
|
||||
self.agent_service._validate_file_upload(features_config, files)
|
||||
|
||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||
|
||||
# 获取模型配置ID
|
||||
@@ -107,17 +118,14 @@ class AppChatService:
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=10
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
@@ -166,6 +174,23 @@ class AppChatService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# suggested_questions
|
||||
suggested_questions = []
|
||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||
suggested_questions = await self.agent_service._generate_suggested_questions(
|
||||
features_config, result["content"],
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base}, {}
|
||||
)
|
||||
|
||||
audio_url = await self.agent_service._generate_tts(
|
||||
features_config, result["content"],
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"message_id": str(message_id),
|
||||
@@ -175,7 +200,10 @@ class AppChatService:
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}),
|
||||
"elapsed_time": elapsed_time
|
||||
"elapsed_time": elapsed_time,
|
||||
"suggested_questions": suggested_questions,
|
||||
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
||||
"audio_url": audio_url,
|
||||
}
|
||||
|
||||
async def agnet_chat_stream(
|
||||
@@ -190,7 +218,7 @@ class AppChatService:
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
files: Optional[List[FileInput]] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
@@ -198,10 +226,19 @@ class AppChatService:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
yield f"event: start\ndata: {json.dumps({
|
||||
'conversation_id': str(conversation_id),
|
||||
"message_id": str(message_id)
|
||||
}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||
web_search = False
|
||||
|
||||
# 校验文件上传
|
||||
self.agent_service._validate_file_upload(features_config, files)
|
||||
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||
# 获取模型配置ID
|
||||
@@ -327,8 +364,22 @@ class AppChatService:
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||
# 发送结束事件(包含 suggested_questions、tts、citations)
|
||||
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||
end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions(
|
||||
features_config, full_content,
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base}, {}
|
||||
)
|
||||
end_data["audio_url"] = await self.agent_service._generate_tts(
|
||||
features_config, full_content,
|
||||
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
@@ -442,7 +493,7 @@ class AppChatService:
|
||||
try:
|
||||
message_id = uuid.uuid4()
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
|
||||
@@ -109,7 +109,7 @@ class AppService:
|
||||
|
||||
return share is not None
|
||||
|
||||
def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
|
||||
def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
|
||||
"""验证应用是否可访问(包括共享应用,用于只读操作)
|
||||
|
||||
Args:
|
||||
@@ -360,6 +360,7 @@ class AppService:
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", []),
|
||||
skills=storage_data.get("skills", {}),
|
||||
features=storage_data.get("features", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1073,6 +1074,7 @@ class AppService:
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
agent_cfg.skills = storage_data.get("skills", {})
|
||||
agent_cfg.features = storage_data.get("features", {})
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
@@ -1173,6 +1175,7 @@ class AppService:
|
||||
variables=[],
|
||||
tools=[],
|
||||
skills=[],
|
||||
features={},
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1389,15 +1392,15 @@ class AppService:
|
||||
|
||||
return config.config_id
|
||||
|
||||
def _update_endusers_memory_config(
|
||||
def _update_endusers_memory_config_by_workspace(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
workspace_id: 工作空间ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
@@ -1406,8 +1409,8 @@ class AppService:
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id(
|
||||
app_id=app_id,
|
||||
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
||||
workspace_id=workspace_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
@@ -1578,11 +1581,15 @@ class AppService:
|
||||
)
|
||||
|
||||
if memory_config_id:
|
||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if app:
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(
|
||||
app.workspace_id, memory_config_id
|
||||
)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
# 更新当前发布版本指针
|
||||
app.current_release_id = release.id
|
||||
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
||||
from app.celery_app import celery_app
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -262,9 +263,12 @@ class AgentRunService:
|
||||
|
||||
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
||||
"""加载工具配置"""
|
||||
if not tools_config:
|
||||
return []
|
||||
tools = []
|
||||
if web_search:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
if not tools_config:
|
||||
return tools
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if tools_config and isinstance(tools_config, list):
|
||||
@@ -273,24 +277,15 @@ class AgentRunService:
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif tools_config and isinstance(tools_config, dict):
|
||||
web_search_choice = tools_config.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search and web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def load_skill_config(
|
||||
@@ -373,6 +368,86 @@ class AgentRunService:
|
||||
)
|
||||
return tools, bool(memory_config.get("enabled"))
|
||||
|
||||
@staticmethod
|
||||
def _validate_file_upload(
|
||||
features_config: Dict[str, Any],
|
||||
files: Optional[List[FileInput]]
|
||||
) -> None:
|
||||
"""校验上传文件是否符合 file_upload 配置"""
|
||||
if not files:
|
||||
return
|
||||
fu = features_config.get("file_upload", {})
|
||||
if not (isinstance(fu, dict) and fu.get("enabled")):
|
||||
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
|
||||
max_count = fu.get("max_file_count", 5)
|
||||
if len(files) > max_count:
|
||||
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
|
||||
|
||||
# 校验传输方式
|
||||
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
|
||||
for f in files:
|
||||
if str(f.transfer_method) not in allowed_methods:
|
||||
raise BusinessException(
|
||||
f"不支持的文件传输方式:{f.transfer_method},允许的方式:{', '.join(allowed_methods)}",
|
||||
BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
# 各类型对应的开关和大小限制配置键
|
||||
type_cfg = {
|
||||
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
|
||||
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
|
||||
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
|
||||
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
|
||||
}
|
||||
|
||||
for f in files:
|
||||
ftype = str(f.type) # 如 "image", "audio", "document", "video"
|
||||
cfg = type_cfg.get(ftype)
|
||||
if cfg is None:
|
||||
continue
|
||||
enabled_key, size_key, default_max_mb, label = cfg
|
||||
|
||||
# 校验类型开关
|
||||
if not fu.get(enabled_key):
|
||||
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
|
||||
|
||||
# 校验文件大小(仅当内容已加载时)
|
||||
content = f.get_content()
|
||||
if content is not None:
|
||||
max_mb = fu.get(size_key, default_max_mb)
|
||||
size_mb = len(content) / (1024 * 1024)
|
||||
if size_mb > max_mb:
|
||||
raise BusinessException(
|
||||
f"{label}文件大小超过限制(最大 {max_mb}MB,当前 {size_mb:.1f}MB)",
|
||||
BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _inject_opening_statement(
|
||||
features_config: Dict[str, Any],
|
||||
system_prompt: str,
|
||||
is_new_conversation: bool
|
||||
) -> str:
|
||||
"""首轮对话时将开场白注入 system_prompt"""
|
||||
if not is_new_conversation:
|
||||
return system_prompt
|
||||
opening = features_config.get("opening_statement", {})
|
||||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||
return system_prompt
|
||||
statement = opening["statement"]
|
||||
return f"{system_prompt}\n\n[对话开场白]\n{statement}"
|
||||
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
features_config: Dict[str, Any],
|
||||
citations: List[Any]
|
||||
) -> List[Any]:
|
||||
"""根据 citation 开关决定是否返回引用来源"""
|
||||
citation_cfg = features_config.get("citation", {})
|
||||
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
||||
return citations
|
||||
return []
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
@@ -415,6 +490,15 @@ class AgentRunService:
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
features_config: dict = agent_config.features or {}
|
||||
|
||||
# 从 features 中读取功能开关(优先级高于参数默认值)
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
|
||||
web_search = False
|
||||
|
||||
# file_upload 校验
|
||||
self._validate_file_upload(features_config, files)
|
||||
|
||||
try:
|
||||
# 1. 获取 API Key 配置
|
||||
@@ -449,6 +533,10 @@ class AgentRunService:
|
||||
# 3. 处理系统提示词(支持变量替换)
|
||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||
|
||||
# opening_statement:首轮对话注入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
|
||||
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
@@ -491,12 +579,10 @@ class AgentRunService:
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
)
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=10
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
@@ -551,7 +637,7 @@ class AgentRunService:
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
if not sub_agent:
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -575,7 +661,15 @@ class AgentRunService:
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}),
|
||||
"elapsed_time": elapsed_time
|
||||
"elapsed_time": elapsed_time,
|
||||
"suggested_questions": await self._generate_suggested_questions(
|
||||
features_config, result["content"], api_key_config, effective_params
|
||||
) if not sub_agent else [],
|
||||
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
||||
"audio_url": await self._generate_tts(
|
||||
features_config, result["content"], api_key_config,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
) if not sub_agent else None,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
@@ -630,6 +724,15 @@ class AgentRunService:
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
features_config: dict = agent_config.features or {}
|
||||
|
||||
# 从 features 中读取功能开关
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||
web_search = False
|
||||
|
||||
# file_upload 校验
|
||||
self._validate_file_upload(features_config, files)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -659,6 +762,10 @@ class AgentRunService:
|
||||
# 3. 处理系统提示词(支持变量替换)
|
||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||
|
||||
# opening_statement:首轮对话注入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
|
||||
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
@@ -703,12 +810,10 @@ class AgentRunService:
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
@@ -774,7 +879,7 @@ class AgentRunService:
|
||||
})
|
||||
|
||||
# 10. 保存会话消息
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
if not sub_agent:
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -786,12 +891,22 @@ class AgentRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 11. 发送结束事件
|
||||
yield self._format_sse_event("end", {
|
||||
# 11. 发送结束事件(包含 suggested_questions 和 tts)
|
||||
end_data: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
})
|
||||
}
|
||||
if not sub_agent:
|
||||
end_data["suggested_questions"] = await self._generate_suggested_questions(
|
||||
features_config, full_content, api_key_config, effective_params
|
||||
)
|
||||
end_data["audio_url"] = await self._generate_tts(
|
||||
features_config, full_content, api_key_config,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
end_data["citations"] = self._filter_citations(features_config, [])
|
||||
yield self._format_sse_event("end", end_data)
|
||||
|
||||
logger.info(
|
||||
"流式试运行完成",
|
||||
@@ -1137,6 +1252,165 @@ class AgentRunService:
|
||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
||||
return {}
|
||||
|
||||
async def _generate_suggested_questions(
|
||||
self,
|
||||
features_config: Dict[str, Any],
|
||||
assistant_message: str,
|
||||
api_key_config: Dict[str, Any],
|
||||
effective_params: Dict[str, Any]
|
||||
) -> List[str]:
|
||||
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
|
||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
|
||||
return []
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
llm = ChatOpenAI(
|
||||
model=api_key_config["model_name"],
|
||||
api_key=api_key_config["api_key"],
|
||||
base_url=api_key_config.get("api_base"),
|
||||
temperature=0.5,
|
||||
max_tokens=200,
|
||||
)
|
||||
prompt = (
|
||||
f"根据以下AI回复,生成3个用户可能继续追问的简短问题,每行一个,不加序号:\n\n{assistant_message}"
|
||||
)
|
||||
resp = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
|
||||
return lines[:3]
|
||||
except Exception as e:
|
||||
logger.warning(f"生成建议问题失败: {e}")
|
||||
return []
|
||||
|
||||
async def _generate_tts(
|
||||
self,
|
||||
features_config: Dict[str, Any],
|
||||
text: str,
|
||||
api_key_config: Dict[str, Any],
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据 text_to_speech 配置生成语音,上传到存储并返回 URL"""
|
||||
tts_config = features_config.get("text_to_speech", {})
|
||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||
return None
|
||||
if not text or not text.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.services.file_storage_service import FileStorageService
|
||||
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
api_key = api_key_config.get("api_key")
|
||||
api_base = api_key_config.get("api_base")
|
||||
voice = tts_config.get("voice")
|
||||
|
||||
if provider == "dashscope":
|
||||
audio_bytes, file_ext, content_type = await self._tts_dashscope(
|
||||
api_key=api_key,
|
||||
text=text,
|
||||
voice=voice or "longxiaochun", # 会根据 model 版本自动修正后缀
|
||||
tts_config=tts_config,
|
||||
)
|
||||
else:
|
||||
# OpenAI 兼容接口(openai / xinference / gpustack 等)
|
||||
audio_bytes, file_ext, content_type = await self._tts_openai(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
text=text,
|
||||
voice=voice or "alloy",
|
||||
)
|
||||
|
||||
storage_service = FileStorageService()
|
||||
file_id = uuid.uuid4()
|
||||
file_key = await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=audio_bytes,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
# 保存文件元数据到数据库
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
db_file = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=f"tts_{file_id}{file_ext}",
|
||||
file_ext=file_ext,
|
||||
file_size=len(audio_bytes),
|
||||
content_type=content_type,
|
||||
status="completed",
|
||||
)
|
||||
self.db.add(db_file)
|
||||
self.db.commit()
|
||||
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
logger.debug(f"TTS 生成成功,provider={provider}, file_key={file_key}")
|
||||
return audio_url
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS 生成失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _tts_openai(
|
||||
api_key: str,
|
||||
api_base: Optional[str],
|
||||
text: str,
|
||||
voice: str,
|
||||
) -> tuple:
|
||||
"""OpenAI 兼容 TTS,返回 (audio_bytes, file_ext, content_type)"""
|
||||
from openai import AsyncOpenAI
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||
response = await client.audio.speech.create(
|
||||
model="tts-1",
|
||||
voice=voice,
|
||||
input=text[:4096],
|
||||
)
|
||||
return response.content, ".mp3", "audio/mpeg"
|
||||
|
||||
@staticmethod
|
||||
async def _tts_dashscope(
|
||||
api_key: str,
|
||||
text: str,
|
||||
voice: str,
|
||||
tts_config: Dict[str, Any],
|
||||
) -> tuple:
|
||||
"""DashScope CosyVoice TTS,返回 (audio_bytes, file_ext, content_type)"""
|
||||
import dashscope
|
||||
from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat
|
||||
|
||||
model = tts_config.get("model") or "cosyvoice-v2"
|
||||
is_v2 = model.endswith("-v2")
|
||||
|
||||
# cosyvoice-v2 音色名带 _v2 后缀,v1 不带
|
||||
# 如果用户传入的 voice 不匹配当前模型版本,自动修正
|
||||
if is_v2 and not voice.endswith("_v2"):
|
||||
voice = voice + "_v2"
|
||||
elif not is_v2 and voice.endswith("_v2"):
|
||||
voice = voice[:-3] # 去掉 _v2
|
||||
|
||||
def _sync_call() -> bytes:
|
||||
dashscope.api_key = api_key
|
||||
synthesizer = SpeechSynthesizer(
|
||||
model=model,
|
||||
voice=voice,
|
||||
format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
|
||||
)
|
||||
audio = synthesizer.call(text[:4096])
|
||||
if audio is None:
|
||||
raise RuntimeError("DashScope TTS 返回空音频")
|
||||
return audio
|
||||
|
||||
audio_bytes = await asyncio.to_thread(_sync_call)
|
||||
return audio_bytes, ".mp3", "audio/mpeg"
|
||||
|
||||
def _replace_variables(
|
||||
self,
|
||||
text: str,
|
||||
@@ -1221,6 +1495,12 @@ class AgentRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 提前校验文件上传(与 run() 内部保持一致)
|
||||
features_config: dict = agent_config.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
# self._validate_file_upload(features_config, files)
|
||||
|
||||
async def run_single_model(model_info):
|
||||
"""运行单个模型"""
|
||||
try:
|
||||
@@ -1271,6 +1551,9 @@ class AgentRunService:
|
||||
if elapsed > 0 and usage.get("completion_tokens") else None
|
||||
),
|
||||
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
||||
"audio_url": result.get("audio_url"),
|
||||
"citations": result.get("citations", []),
|
||||
"suggested_questions": result.get("suggested_questions", []),
|
||||
"error": None
|
||||
}
|
||||
|
||||
@@ -1343,7 +1626,12 @@ class AgentRunService:
|
||||
)
|
||||
|
||||
return {
|
||||
"results": results,
|
||||
"results": [{
|
||||
**r,
|
||||
"audio_url": r.get("audio_url"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
} for r in results],
|
||||
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
|
||||
"successful_count": len(successful),
|
||||
"failed_count": len(failed),
|
||||
@@ -1434,6 +1722,12 @@ class AgentRunService:
|
||||
extra={"model_count": len(models), "parallel": parallel}
|
||||
)
|
||||
|
||||
# 提前校验文件上传
|
||||
# features_config: dict = agent_config.features or {}
|
||||
# if hasattr(features_config, 'model_dump'):
|
||||
# features_config = features_config.model_dump()
|
||||
# self._validate_file_upload(features_config, files)
|
||||
|
||||
# 发送开始事件
|
||||
yield self._format_sse_event("compare_start", {
|
||||
"conversation_id": conversation_id,
|
||||
@@ -1465,6 +1759,9 @@ class AgentRunService:
|
||||
start_time = time.time()
|
||||
full_content = ""
|
||||
returned_conversation_id = model_conversation_id
|
||||
audio_url = None
|
||||
citations = []
|
||||
suggested_questions = []
|
||||
|
||||
# 临时修改参数
|
||||
original_params = agent_config.model_parameters
|
||||
@@ -1518,6 +1815,12 @@ class AgentRunService:
|
||||
"content": chunk
|
||||
}))
|
||||
|
||||
# 从 end 事件中提取 features 输出字段
|
||||
if event_type == "end" and event_data:
|
||||
audio_url = event_data.get("audio_url")
|
||||
citations = event_data.get("citations", [])
|
||||
suggested_questions = event_data.get("suggested_questions", [])
|
||||
|
||||
if event_type == "error" and event_data:
|
||||
await event_queue.put(self._format_sse_event("model_error", {
|
||||
"model_index": idx,
|
||||
@@ -1543,6 +1846,9 @@ class AgentRunService:
|
||||
"parameters_used": model_info["parameters"],
|
||||
"message": full_content,
|
||||
"elapsed_time": elapsed,
|
||||
"audio_url": audio_url,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"error": None
|
||||
}
|
||||
|
||||
@@ -1554,6 +1860,9 @@ class AgentRunService:
|
||||
"conversation_id": returned_conversation_id,
|
||||
"elapsed_time": elapsed,
|
||||
"message_length": len(full_content),
|
||||
"audio_url": audio_url,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"timestamp": time.time()
|
||||
}))
|
||||
|
||||
@@ -1685,8 +1994,11 @@ class AgentRunService:
|
||||
"model_name": r["model_name"],
|
||||
"label": r["label"],
|
||||
"conversation_id": r.get("conversation_id"),
|
||||
"message": r.get("message"), # 包含完整消息
|
||||
"message": r.get("message"),
|
||||
"elapsed_time": r.get("elapsed_time", 0),
|
||||
"audio_url": r.get("audio_url"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
"error": r.get("error")
|
||||
})
|
||||
|
||||
|
||||
@@ -68,14 +68,14 @@ def get_workspace_end_users(
|
||||
return []
|
||||
|
||||
# 提取所有 app_id
|
||||
app_ids = [app.id for app in apps_orm]
|
||||
# app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.app_id.in_(app_ids)
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
|
||||
@@ -78,7 +78,7 @@ class ToolService:
|
||||
|
||||
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
|
||||
"""获取工具详情"""
|
||||
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
return self._config_to_info(config) if config else None
|
||||
|
||||
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
|
||||
@@ -237,7 +237,7 @@ class ToolService:
|
||||
return False
|
||||
|
||||
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除)"""
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config:
|
||||
return False
|
||||
@@ -246,14 +246,7 @@ class ToolService:
|
||||
raise ValueError("内置工具不允许删除")
|
||||
|
||||
try:
|
||||
# 删除关联表记录
|
||||
if config.tool_type == ToolType.CUSTOM.value:
|
||||
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
|
||||
elif config.tool_type == ToolType.MCP.value:
|
||||
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
|
||||
|
||||
# 删除主表记录(ToolExecution会通过cascade自动删除)
|
||||
self.db.delete(config)
|
||||
config.is_active = False
|
||||
self._clear_tool_cache(tool_id)
|
||||
self.db.commit()
|
||||
return True
|
||||
@@ -262,6 +255,27 @@ class ToolService:
|
||||
logger.error(f"删除工具失败: {tool_id}, {e}")
|
||||
return False
|
||||
|
||||
def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool:
|
||||
"""设置工具可用状态(启用/禁用)"""
|
||||
# 直接查询,包含 is_active=False 的记录
|
||||
config = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.id == uuid.UUID(tool_id),
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
).first()
|
||||
if not config:
|
||||
return False
|
||||
if config.tool_type == ToolType.BUILTIN.value:
|
||||
raise ValueError("内置工具不允许修改可用状态")
|
||||
try:
|
||||
config.is_active = is_active
|
||||
self._clear_tool_cache(tool_id)
|
||||
self.db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"设置工具状态失败: {tool_id}, {e}")
|
||||
return False
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
@@ -378,7 +392,7 @@ class ToolService:
|
||||
Returns:
|
||||
方法列表或None
|
||||
"""
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
config = self._get_tool_config_all(tool_id, tenant_id)
|
||||
if not config:
|
||||
return None
|
||||
|
||||
@@ -857,16 +871,20 @@ class ToolService:
|
||||
}
|
||||
|
||||
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""获取工具配置"""
|
||||
"""获取工具配置(仅返回 is_active=True)"""
|
||||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
|
||||
def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||
"""获取工具配置(返回所有)"""
|
||||
return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
|
||||
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||
"""获取工具实例"""
|
||||
"""获取工具实例(仅返回 is_active=True 的工具)"""
|
||||
if tool_id in self._tool_cache:
|
||||
return self._tool_cache[tool_id]
|
||||
|
||||
config = self._get_tool_config(tool_id, tenant_id)
|
||||
if not config:
|
||||
if not config or not config.is_active:
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -980,6 +998,7 @@ class ToolService:
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None,
|
||||
config_data=config_data,
|
||||
is_active=config.is_active,
|
||||
created_at=config.created_at
|
||||
)
|
||||
|
||||
|
||||
@@ -1292,9 +1292,9 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
# app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
EndUser.workspace_id == workspace_id
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
@@ -1433,9 +1433,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
# app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
EndUser.workspace_id == workspace_id
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
|
||||
Reference in New Issue
Block a user