feat: Add base project structure with API and web components
This commit is contained in:
60
api/app/controllers/__init__.py
Normal file
60
api/app/controllers/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""管理端接口 - 基于 JWT 认证
|
||||
|
||||
路由前缀: /
|
||||
认证方式: JWT Token
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import (
|
||||
model_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
user_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
multi_agent_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
|
||||
# 注册所有管理端路由
|
||||
manager_router.include_router(task_controller.router)
|
||||
manager_router.include_router(user_controller.router)
|
||||
manager_router.include_router(auth_controller.router)
|
||||
manager_router.include_router(workspace_controller.router)
|
||||
manager_router.include_router(workspace_controller.public_router) # 公开路由(无需认证)
|
||||
manager_router.include_router(setup_controller.router)
|
||||
manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
manager_router.include_router(app_controller.router)
|
||||
manager_router.include_router(upload_controller.router)
|
||||
manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(memory_storage_controller.router)
|
||||
manager_router.include_router(api_key_controller.router)
|
||||
manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
151
api/app/controllers/api_key_controller.py
Normal file
151
api/app/controllers/api_key_controller.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models.user_model import User
|
||||
from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.api_key_service import ApiKeyService
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def create_api_key(
|
||||
data: api_key_schema.ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建 API Key
|
||||
|
||||
- 创建后返回明文 API Key(仅此一次)
|
||||
- 支持设置权限范围、速率限制、配额等
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def list_api_keys(
|
||||
type: api_key_schema.ApiKeyType = Query(None),
|
||||
is_active: bool = Query(None),
|
||||
resource_id: uuid.UUID = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""列出 API Keys"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
query = api_key_schema.ApiKeyQuery(
|
||||
type=type,
|
||||
is_active=is_active,
|
||||
resource_id=resource_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
result = ApiKeyService.list_api_keys(db, workspace_id, query)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 详情"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key))
|
||||
|
||||
|
||||
@router.put("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def update_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
|
||||
@router.delete("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(msg="API Key 删除成功")
|
||||
|
||||
|
||||
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def regenerate_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""重新生成 API Key
|
||||
|
||||
- 生成新的 API Key 并返回明文(仅此一次)
|
||||
- 旧的 API Key 立即失效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 重新生成成功")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key_stats(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 使用统计"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=stats)
|
||||
716
api/app/controllers/app_controller.py
Normal file
716
api/app/controllers/app_controller.py
Normal file
@@ -0,0 +1,716 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.models.app_model import AppType
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("", summary="创建应用(可选创建 Agent 配置)")
|
||||
@cur_workspace_access_guard()
|
||||
def create_app(
|
||||
payload: app_schema.AppCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("", summary="应用列表(分页)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_apps(
|
||||
type: str | None = None,
|
||||
visibility: str | None = None,
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用
|
||||
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
type=type,
|
||||
visibility=visibility,
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
# 使用 AppService 的转换方法来设置 is_shared 字段
|
||||
service = app_service.AppService(db)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用详细信息
|
||||
|
||||
- 支持获取本工作空间的应用
|
||||
- 支持获取分享给本工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
app = service.get_app(app_id, workspace_id)
|
||||
|
||||
# 转换为 Schema 并设置 is_shared 字段
|
||||
app_schema_obj = service._convert_to_schema(app, workspace_id)
|
||||
return success(data=app_schema_obj)
|
||||
|
||||
|
||||
@router.put("/{app_id}", summary="更新应用基本信息")
|
||||
@cur_workspace_access_guard()
|
||||
def update_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.delete("/{app_id}", summary="删除应用")
|
||||
@cur_workspace_access_guard()
|
||||
def delete_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""删除应用
|
||||
|
||||
会级联删除:
|
||||
- Agent 配置
|
||||
- 发布版本
|
||||
- 会话和消息
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求删除应用",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
|
||||
|
||||
return success(msg="应用删除成功")
|
||||
|
||||
|
||||
@router.post("/{app_id}/copy", summary="复制应用")
|
||||
@cur_workspace_access_guard()
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""复制应用(包括基础信息和配置)
|
||||
|
||||
- 复制应用的基础信息(名称、描述、图标等)
|
||||
- 复制 Agent 配置(如果是 agent 类型)
|
||||
- 新应用默认为草稿状态
|
||||
- 不影响原应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求复制应用",
|
||||
extra={
|
||||
"source_app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"new_name": new_name
|
||||
}
|
||||
)
|
||||
|
||||
service = AppService(db)
|
||||
new_app = service.copy_app(
|
||||
app_id=app_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
new_name=new_name
|
||||
)
|
||||
|
||||
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
|
||||
|
||||
|
||||
@router.put("/{app_id}/config", summary="更新 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def update_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AgentConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
cfg = enrich_agent_config(cfg)
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/config", summary="获取 Agent 配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_agent_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
# 配置总是存在(不存在时返回默认模板)
|
||||
cfg = enrich_agent_config(cfg)
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.PublishRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.publish(
|
||||
db,
|
||||
app_id=app_id,
|
||||
publisher_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
version_name = payload.version_name,
|
||||
release_notes=payload.release_notes
|
||||
)
|
||||
return success(data=app_schema.AppRelease.model_validate(release))
|
||||
|
||||
|
||||
@router.get("/{app_id}/release", summary="获取当前发布版本")
|
||||
@cur_workspace_access_guard()
|
||||
def get_current_release(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
|
||||
if not release:
|
||||
return success(data=None)
|
||||
return success(data=app_schema.AppRelease.model_validate(release))
|
||||
|
||||
|
||||
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
|
||||
@cur_workspace_access_guard()
|
||||
def list_releases(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
|
||||
data = [app_schema.AppRelease.model_validate(r) for r in releases]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
|
||||
@cur_workspace_access_guard()
|
||||
def rollback(
|
||||
app_id: uuid.UUID,
|
||||
version: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
|
||||
return success(data=app_schema.AppRelease.model_validate(release))
|
||||
|
||||
|
||||
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
|
||||
@cur_workspace_access_guard()
|
||||
def share_app(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.AppShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
- 只能分享自己工作空间的应用
|
||||
- 不能分享到自己的工作空间
|
||||
- 同一个应用不能重复分享到同一个工作空间
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
shares = service.share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
|
||||
|
||||
|
||||
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_app(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""取消应用分享
|
||||
|
||||
- 只能取消自己工作空间应用的分享
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.unshare_app(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用的所有分享记录
|
||||
|
||||
- 只能查看自己工作空间应用的分享记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_app_shares(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||
|
||||
- 不需要发布应用即可测试
|
||||
- 使用当前的 AgentConfig 配置
|
||||
- 支持流式和非流式返回
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
|
||||
service = AppService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
if app.type == AppType.AGENT:
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
logger.debug(
|
||||
f"开始非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
|
||||
}
|
||||
)
|
||||
|
||||
# 验证结果
|
||||
try:
|
||||
validated_result = app_schema.DraftRunResponse.model_validate(result)
|
||||
logger.debug(f"结果验证成功")
|
||||
return success(data=validated_result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"结果验证失败",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"error_type": str(type(e)),
|
||||
"result": str(result)[:200]
|
||||
}
|
||||
)
|
||||
raise
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
# 1. 检查多智能体配置完整性
|
||||
service._check_multi_agent_config(app_id)
|
||||
|
||||
# 2. 构建多智能体运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=payload.message,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables or {},
|
||||
use_llm_routing=True # 默认启用 LLM 路由
|
||||
)
|
||||
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
f"开始多智能体流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 4. 非流式返回
|
||||
logger.debug(
|
||||
f"开始多智能体非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
multiservice = MultiAgentService(db)
|
||||
result = await multiservice.run(app_id, multi_agent_request)
|
||||
|
||||
logger.debug(
|
||||
f"多智能体试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
|
||||
return success(
|
||||
data=result,
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run_compare(
|
||||
app_id: uuid.UUID,
|
||||
payload: app_schema.DraftRunCompareRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
多模型对比试运行
|
||||
|
||||
- 支持对比 1-5 个模型
|
||||
- 可以是不同的模型,也可以是同一模型的不同参数配置
|
||||
- 通过 model_parameters 覆盖默认参数
|
||||
- 支持并行或串行执行(非流式)
|
||||
- 支持流式返回(串行执行)
|
||||
- 返回每个模型的运行结果和性能对比
|
||||
|
||||
使用场景:
|
||||
1. 对比不同模型的效果(GPT-4 vs Claude vs Gemini)
|
||||
2. 调优模型参数(不同 temperature 的效果对比)
|
||||
3. 性能和成本分析
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(payload.models),
|
||||
"parallel": payload.parallel,
|
||||
"stream": payload.stream
|
||||
}
|
||||
)
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.models import ModelConfig
|
||||
|
||||
service = AppService(db)
|
||||
|
||||
# 1. 验证应用和权限
|
||||
app = service._get_app_or_404(app_id)
|
||||
if app.type != "agent":
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 验证所有模型配置
|
||||
model_configs = []
|
||||
for model_item in payload.models:
|
||||
model_config = db.get(ModelConfig, model_item.model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
"model_config": model_config,
|
||||
"parameters": merged_parameters,
|
||||
"label": model_item.label or model_config.name,
|
||||
"model_config_id": model_item.model_config_id,
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"successful": result["successful_count"],
|
||||
"failed": result["failed_count"]
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=app_schema.DraftRunCompareResponse(**result))
|
||||
195
api/app/controllers/auth_controller.py
Normal file
195
api/app/controllers/auth_controller.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest
|
||||
from app.schemas.workspace_schema import InviteAcceptRequest
|
||||
from app.services import auth_service, user_service, workspace_service
|
||||
from app.core import security
|
||||
from app.core.config import settings
|
||||
from app.services.session_service import SessionService
|
||||
from app.core.logging_config import get_auth_logger, get_security_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
security_logger = get_security_logger()
|
||||
|
||||
router = APIRouter(tags=["Authentication"])
|
||||
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
|
||||
# 验证邀请码(如果提供)
|
||||
invite_info = None
|
||||
# 验证用户凭据或注册新用户
|
||||
user = None
|
||||
if form_data.invite:
|
||||
auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...")
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
# 尝试认证用户
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
refresh_token, refresh_token_id = security.create_refresh_token(subject=user.id)
|
||||
|
||||
# 计算过期时间
|
||||
access_expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
# 单点登录会话管理
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
await SessionService.invalidate_old_session(user.id, access_token_id)
|
||||
await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at)
|
||||
|
||||
# 更新最后登录时间
|
||||
user_service.update_last_login_time(db, user.id)
|
||||
|
||||
auth_logger.info(f"用户 {user.username} 登录成功")
|
||||
|
||||
return success(
|
||||
data=Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
new_refresh_token, new_refresh_token_id = security.create_refresh_token(subject=user.id)
|
||||
|
||||
# 计算过期时间
|
||||
access_expires_at = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
refresh_expires_at = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
# 单点登录会话管理
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
# 将旧 refresh token 加入黑名单
|
||||
old_refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if old_refresh_token_id:
|
||||
await SessionService.blacklist_token(old_refresh_token_id)
|
||||
|
||||
# 更新会话
|
||||
await SessionService.invalidate_old_session(user.id, new_access_token_id)
|
||||
await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at)
|
||||
|
||||
auth_logger.info(f"用户 {user.id} token刷新成功")
|
||||
|
||||
return success(
|
||||
data=Token(
|
||||
access_token=new_access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
token_type="bearer",
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=ApiResponse)
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
|
||||
# 清理会话
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
|
||||
447
api/app/controllers/chunk_controller.py
Normal file
447
api/app/controllers/chunk_controller.py
Normal file
@@ -0,0 +1,447 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import get_db
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models.document_model import Document
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/chunks",
|
||||
tags=["chunks"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
|
||||
async def get_preview_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Paged query document block preview list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
# 3. Check if the document exists
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
|
||||
if not db_file:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||
file_path = os.path.join(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.parent_id),
|
||||
f"{db_file.id}{db_file.file_ext}"
|
||||
)
|
||||
|
||||
# 6. Check if the file exists
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
)
|
||||
|
||||
# 7. Document parsing & segmentation
|
||||
def progress_callback(prog=None, msg=None):
|
||||
print(f"prog: {prog} msg: {msg}\n")
|
||||
# Prepare to configure vision_model information
|
||||
vision_model = QWenCV(
|
||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||
lang="Chinese", # Default to Chinese
|
||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||
)
|
||||
from app.core.rag.app.naive import chunk
|
||||
res = chunk(filename=file_path,
|
||||
from_page=0,
|
||||
to_page=5,
|
||||
callback=progress_callback,
|
||||
vision_model=vision_model,
|
||||
parser_config=db_document.parser_config,
|
||||
is_root=False)
|
||||
|
||||
start_index = (page - 1) * pagesize
|
||||
end_index = start_index + pagesize
|
||||
# Use slicing to obtain the data of the current page
|
||||
paginated_chunk_str_list = res[start_index:end_index]
|
||||
chunks = []
|
||||
for idx, item in enumerate(paginated_chunk_str_list):
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": idx,
|
||||
"status": 1,
|
||||
}
|
||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||
|
||||
# 8. Return structured response
|
||||
total = len(res)
|
||||
result = {
|
||||
"items": chunks,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
|
||||
return success(data=result, msg="Querying the document block preview list succeeded")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
|
||||
async def get_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Paged query document chunk list
|
||||
- Support filtering by document_id
|
||||
- Support keyword search for segmented content
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document chunk query")
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
|
||||
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document chunk query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of document chunk list succeeded")
|
||||
|
||||
|
||||
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
|
||||
async def create_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
create_data: chunk_schema.ChunkCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create chunk
|
||||
"""
|
||||
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
|
||||
|
||||
# 1. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
# 1. Obtain document information
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# 2. Get the sort ID
|
||||
sort_id = 0
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
sort_id = sort_id + 1
|
||||
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(document_id),
|
||||
"knowledge_id": str(kb_id),
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
|
||||
# 3. Segmented vector storage
|
||||
vector_service.add_chunks([chunk])
|
||||
|
||||
# 4.update chunk_num
|
||||
db_document.chunk_num += 1
|
||||
db.commit()
|
||||
|
||||
return success(data=chunk, msg="Document chunk creation successful")
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def get_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve document chunk information based on doc_id
|
||||
"""
|
||||
api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
return success(data=items[0], msg="Document chunk query successful")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document chunk does not exist or you do not have access"
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def update_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
update_data: chunk_schema.ChunkUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update document chunk content
|
||||
"""
|
||||
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
chunk = items[0]
|
||||
chunk.page_content = update_data.content
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=chunk, msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document chunk does not exist or you do not have access to it"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||
async def delete_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
doc_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete document chunk
|
||||
"""
|
||||
api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
if vector_service.text_exists(doc_id):
|
||||
vector_service.delete_by_ids([doc_id])
|
||||
# 更新 chunk_num
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
db_document.chunk_num -= 1
|
||||
db.commit()
|
||||
return success(msg="The document chunk has been successfully deleted")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document chunk does not exist or you do not have access to it"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/retrieve_type", response_model=ApiResponse)
|
||||
def get_retrieve_types():
|
||||
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
|
||||
|
||||
|
||||
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
|
||||
async def retrieve_chunks(
|
||||
retrieve_data: chunk_schema.ChunkRetrieve,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
retrieve chunk
|
||||
"""
|
||||
api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}")
|
||||
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
existing_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
share_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
if not existing_ids:
|
||||
return success(data=[], msg="retrieval successful")
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
341
api/app/controllers/document_controller.py
Normal file
341
api/app/controllers/document_controller.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import document_model
|
||||
from app.schemas import document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
from app.controllers import file_controller
|
||||
from app.celery_app import celery_app
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
|
||||
async def get_documents(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
document_ids: Optional[str] = Query(None, description="document ids, separated by commas"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Paged query document list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
api_logger.info(f"Query document list: kb_id={kb_id}, page={page}, pagesize={pagesize}, keywords={keywords}, document_ids={document_ids}, username: {current_user.username}")
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = [
|
||||
document_model.Document.kb_id == kb_id,
|
||||
document_model.Document.status == 1
|
||||
]
|
||||
|
||||
if parent_id:
|
||||
files = file_service.get_files_by_parent_id(db=db, parent_id=parent_id, current_user=current_user)
|
||||
files_ids = [item.id for item in files]
|
||||
filters.append(document_model.Document.file_id.in_(files_ids))
|
||||
|
||||
# Keyword search (fuzzy matching of file name)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(document_model.Document.file_name.ilike(f"%{keywords}%"))
|
||||
# document ids
|
||||
if document_ids:
|
||||
filters.append(document_model.Document.id.in_(document_ids.split(',')))
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document paging query")
|
||||
total, items = document_service.get_documents_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"Document query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of document list succeeded")
|
||||
|
||||
|
||||
@router.post("/document", response_model=ApiResponse)
|
||||
async def create_document(
|
||||
create_data: document_schema.DocumentCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create document
|
||||
"""
|
||||
api_logger.info(f"Create document request: file_name={create_data.file_name}, kb_id={create_data.kb_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating a document: {create_data.file_name}")
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{document_id}", response_model=ApiResponse)
|
||||
async def get_document(
|
||||
document_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve document information based on document_id
|
||||
"""
|
||||
api_logger.info(f"Obtain document information: document_id={document_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query document information from the database
|
||||
api_logger.debug(f"query documentation: {document_id}")
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
if not db_document:
|
||||
api_logger.warning(f"The document does not exist or you do not have access: document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have access"
|
||||
)
|
||||
|
||||
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Document query failed: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{document_id}", response_model=ApiResponse)
|
||||
async def update_document(
|
||||
document_id: uuid.UUID,
|
||||
update_data: document_schema.DocumentUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update document information
|
||||
"""
|
||||
# 1. Check if the document exists
|
||||
api_logger.debug(f"Query the document to be updated: {document_id}")
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
|
||||
if not db_document:
|
||||
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. If updating the status, synchronize the document status switch to whether it can be retrieved from the vector database
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "status" in update_dict:
|
||||
new_status = update_dict["status"]
|
||||
if new_status != db_document.status:
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
vector_service.change_status_by_document_id(document_id=str(document_id), status=new_status)
|
||||
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the document fields: {document_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_document, field):
|
||||
old_value = getattr(db_document, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_document, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
db_document.updated_at = datetime.datetime.now()
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
api_logger.info(f"The document has been successfully updated: {db_document.file_name} (ID: {db_document.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"Document update failed: document_id={document_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Document update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated document
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{document_id}", response_model=ApiResponse)
|
||||
async def delete_document(
|
||||
document_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete document
|
||||
"""
|
||||
api_logger.info(f"Request to delete document: document_id={document_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check if the document exists
|
||||
api_logger.debug(f"Check whether the document exists: {document_id}")
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
|
||||
if not db_document:
|
||||
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have permission to access it"
|
||||
)
|
||||
file_id = db_document.file_id
|
||||
|
||||
# 2. Delete document
|
||||
api_logger.debug(f"Perform document delete: {db_document.file_name} (ID: {document_id})")
|
||||
db.delete(db_document)
|
||||
db.commit()
|
||||
|
||||
# 3. Delete file
|
||||
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||
|
||||
# 4. Delete vector index
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
|
||||
api_logger.info(f"The document has been successfully deleted: {db_document.file_name} (ID: {document_id})")
|
||||
return success(msg="The document has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the document: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{document_id}/chunks", response_model=ApiResponse)
|
||||
async def parse_documents(
|
||||
document_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
parse document
|
||||
"""
|
||||
api_logger.info(f"Request to parse document: document_id={document_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check if the document exists
|
||||
api_logger.debug(f"Check whether the document exists: {document_id}")
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
|
||||
if not db_document:
|
||||
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The document does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Check if the file exists
|
||||
api_logger.debug(f"Check whether the file exists: {db_document.file_id}")
|
||||
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
|
||||
|
||||
if not db_file:
|
||||
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={db_document.file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||
file_path = os.path.join(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.parent_id),
|
||||
f"{db_file.id}{db_file.file_ext}"
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
)
|
||||
|
||||
# 5. Obtain knowledge base information
|
||||
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 6. Task: Document parsing, vectorization, and storage
|
||||
# from app.tasks import parse_document
|
||||
# parse_document(file_path, document_id)
|
||||
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
return success(data=result, msg="Task accepted. The document is being processed in the background.")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to parse document: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
453
api/app/controllers/file_controller.py
Normal file
453
api/app/controllers/file_controller.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import file_model
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import file_service, document_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/files",
|
||||
tags=["files"]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
|
||||
async def get_files(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Paged query file list
|
||||
- Support filtering by kb_id and parent_id
|
||||
- Support keyword search for file names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = [
|
||||
file_model.File.kb_id == kb_id
|
||||
]
|
||||
if parent_id:
|
||||
filters.append(file_model.File.parent_id == parent_id)
|
||||
# Keyword search (fuzzy matching of file name)
|
||||
if keywords:
|
||||
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing file paging query")
|
||||
total, items = file_service.get_files_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of file list succeeded")
|
||||
|
||||
|
||||
@router.post("/folder", response_model=ApiResponse)
|
||||
def create_folder(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
folder_name: str = '/',
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new folder
|
||||
"""
|
||||
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating a folder: {folder_name}")
|
||||
create_folder = file_schema.FileCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
parent_id=parent_id,
|
||||
file_name=folder_name,
|
||||
file_ext='folder',
|
||||
file_size=0,
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
||||
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
upload file
|
||||
"""
|
||||
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
|
||||
|
||||
# Read the contents of the file
|
||||
contents = await file.read()
|
||||
# Check file size
|
||||
file_size = len(contents)
|
||||
print(f"file size: {file_size} byte")
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||
)
|
||||
|
||||
# Extract the extension using `os.path.splitext`
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
upload_file = file_schema.FileCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
parent_id=parent_id,
|
||||
file_name=file.filename,
|
||||
file_ext=file_extension.lower(),
|
||||
file_size=file_size,
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||
|
||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||
save_path = os.path.join(save_dir, f"{db_file.id}{file_extension}")
|
||||
|
||||
# Save file
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
# Verify whether the file has been saved successfully
|
||||
if not os.path.exists(save_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="File save failed"
|
||||
)
|
||||
|
||||
# Create a document
|
||||
create_data = document_schema.DocumentCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
file_id=db_file.id,
|
||||
file_name=db_file.file_name,
|
||||
file_ext=db_file.file_ext,
|
||||
file_size=db_file.file_size,
|
||||
file_meta={},
|
||||
parser_id="naive",
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": "false"
|
||||
}
|
||||
)
|
||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
|
||||
|
||||
|
||||
@router.post("/customtext", response_model=ApiResponse)
|
||||
async def custom_text(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
create_data: file_schema.CustomTextFileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
custom text
|
||||
"""
|
||||
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
|
||||
|
||||
# Check file content size
|
||||
# 将内容编码为字节(UTF-8)
|
||||
content_bytes = create_data.content.encode('utf-8')
|
||||
file_size = len(content_bytes)
|
||||
print(f"file size: {file_size} byte")
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The content is empty."
|
||||
)
|
||||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||
)
|
||||
|
||||
upload_file = file_schema.FileCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
parent_id=parent_id,
|
||||
file_name=f"{create_data.title}.txt",
|
||||
file_ext=".txt",
|
||||
file_size=file_size,
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||
|
||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
||||
|
||||
# Save file
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(content_bytes)
|
||||
|
||||
# Verify whether the file has been saved successfully
|
||||
if not os.path.exists(save_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="File save failed"
|
||||
)
|
||||
|
||||
# Create a document
|
||||
create_document_data = document_schema.DocumentCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
file_id=db_file.id,
|
||||
file_name=db_file.file_name,
|
||||
file_ext=db_file.file_ext,
|
||||
file_size=db_file.file_size,
|
||||
file_meta={},
|
||||
parser_id="naive",
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": "false"
|
||||
}
|
||||
)
|
||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||
|
||||
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
|
||||
|
||||
@router.get("/{file_id}", response_model=Any)
|
||||
async def get_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
) -> Any:
|
||||
"""
|
||||
Download the file based on the file_id
|
||||
- Query file information from the database
|
||||
- Construct the file path and check if it exists
|
||||
- Return a FileResponse to download the file
|
||||
"""
|
||||
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
|
||||
|
||||
# 1. Query file information from the database
|
||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||
if not db_file:
|
||||
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||
file_path = os.path.join(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.parent_id),
|
||||
f"{db_file.id}{db_file.file_ext}"
|
||||
)
|
||||
|
||||
# 3. Check if the file exists
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
)
|
||||
|
||||
# 4.Return FileResponse (automatically handle download)
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
filename=db_file.file_name, # Use original file name
|
||||
media_type="application/octet-stream" # Universal binary stream type
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{file_id}", response_model=ApiResponse)
|
||||
async def update_file(
|
||||
file_id: uuid.UUID,
|
||||
update_data: file_schema.FileUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Update file information (such as file name)
|
||||
- Only specified fields such as file_name are allowed to be modified
|
||||
"""
|
||||
api_logger.debug(f"Query the file to be updated: {file_id}")
|
||||
|
||||
# 1. Check if the file exists
|
||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||
|
||||
if not db_file:
|
||||
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the file fields: {file_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_file, field):
|
||||
old_value = getattr(db_file, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_file, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_file)
|
||||
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated file
|
||||
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{file_id}", response_model=ApiResponse)
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a file or folder
|
||||
"""
|
||||
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
|
||||
await _delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||
return success(msg="File deleted successfully")
|
||||
|
||||
async def _delete_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> None:
|
||||
"""
|
||||
Delete a file or folder
|
||||
"""
|
||||
# 1. Check if the file exists
|
||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||
|
||||
if not db_file:
|
||||
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Construct physical path
|
||||
file_path = Path(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.id)
|
||||
) if db_file.file_ext == 'folder' else Path(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.parent_id),
|
||||
f"{db_file.id}{db_file.file_ext}"
|
||||
)
|
||||
|
||||
# 3. Delete physical files/folders
|
||||
try:
|
||||
if file_path.exists():
|
||||
if db_file.file_ext == 'folder':
|
||||
shutil.rmtree(file_path) # Recursively delete folders
|
||||
else:
|
||||
file_path.unlink() # Delete a single file
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete physical file/folder: {str(e)}"
|
||||
)
|
||||
|
||||
# 4.Delete db_file
|
||||
if db_file.file_ext == 'folder':
|
||||
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
|
||||
db.delete(db_file)
|
||||
db.commit()
|
||||
305
api/app/controllers/knowledge_controller.py
Normal file
305
api/app/controllers/knowledge_controller.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import knowledge_model, document_model, file_model
|
||||
from app.schemas import knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/knowledges",
|
||||
tags=["knowledges"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/knowledgetype", response_model=ApiResponse)
|
||||
def get_knowledge_types():
|
||||
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
|
||||
|
||||
|
||||
@router.get("/permissiontype", response_model=ApiResponse)
|
||||
def get_permission_types():
|
||||
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
|
||||
|
||||
|
||||
@router.get("/parsertype", response_model=ApiResponse)
|
||||
def get_parser_types():
|
||||
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
|
||||
|
||||
|
||||
@router.get("/knowledges", response_model=ApiResponse)
|
||||
async def get_knowledges(
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
|
||||
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the knowledge base list in pages
|
||||
- Support filtering by parent_id
|
||||
- Support keyword search for knowledge base names
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + file list
|
||||
"""
|
||||
api_logger.info(f"Query knowledge base list: workspace_id={current_user.current_workspace_id}, page={page}, pagesize={pagesize}, keywords={keywords}, kb_ids={kb_ids}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = [
|
||||
knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
|
||||
]
|
||||
if parent_id:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
|
||||
|
||||
# Keyword search (fuzzy matching of knowledge base name)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
knowledge_model.Knowledge.name.ilike(f"%{keywords}%"),
|
||||
knowledge_model.Knowledge.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# Knowledge base ids
|
||||
if kb_ids:
|
||||
filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(',')))
|
||||
else:
|
||||
filters.append(knowledge_model.Knowledge.status != 2)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base paging query")
|
||||
total, items = knowledge_service.get_knowledges_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"Knowledge base query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge base query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page*pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of knowledge base list successful")
|
||||
|
||||
|
||||
@router.post("/knowledge", response_model=ApiResponse)
|
||||
async def create_knowledge(
|
||||
create_data: knowledge_schema.KnowledgeCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create knowledge
|
||||
"""
|
||||
api_logger.info(f"Request to create a knowledge base: name={create_data.name}, workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the knowledge base: {create_data.name}")
|
||||
# 1. Check if the knowledge base name already exists
|
||||
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_knowledge_exist:
|
||||
api_logger.warning(f"The knowledge base name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The knowledge base name already exists: {create_data.name}"
|
||||
)
|
||||
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
|
||||
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{knowledge_id}", response_model=ApiResponse)
|
||||
async def get_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base information from the database
|
||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge base query failed: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{knowledge_id}", response_model=ApiResponse)
|
||||
async def update_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
update_data: knowledge_schema.KnowledgeUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
|
||||
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
|
||||
|
||||
|
||||
async def _update_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
update_data: knowledge_schema.KnowledgeUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> knowledge_schema.Knowledge:
|
||||
"""
|
||||
Update knowledge base information
|
||||
"""
|
||||
try:
|
||||
# 1. Check whether the knowledge base exists
|
||||
api_logger.debug(f"Query the knowledge base to be updated: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. If updating the embedding_id, delete the knowledge base vector index, reset all document parsing progress to 0, and set chunk_num to 0
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_knowledge.name:
|
||||
# Check if the knowledge base name already exists
|
||||
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=name, current_user=current_user)
|
||||
if db_knowledge_exist:
|
||||
api_logger.warning(f"The knowledge base name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The knowledge base name already exists: {name}"
|
||||
)
|
||||
if "embedding_id" in update_dict:
|
||||
embedding_id = update_dict["embedding_id"]
|
||||
if embedding_id != db_knowledge.embedding_id:
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
vector_service.delete()
|
||||
document_service.reset_documents_progress_by_kb_id(db, kb_id=db_knowledge.id, current_user=current_user)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the knowledge base fields: {knowledge_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_data.dict(exclude_unset=True).items():
|
||||
if hasattr(db_knowledge, field):
|
||||
old_value = getattr(db_knowledge, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_knowledge, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
db_knowledge.updated_at = datetime.datetime.now()
|
||||
|
||||
# 3. Save to database
|
||||
db.commit()
|
||||
db.refresh(db_knowledge)
|
||||
api_logger.info(f"The knowledge base has been successfully updated: {db_knowledge.name} (ID: {db_knowledge.id})")
|
||||
|
||||
# 4. Return the updated knowledge base
|
||||
return db_knowledge
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"Knowledge base update failed: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Knowledge base update failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{knowledge_id}", response_model=ApiResponse)
|
||||
async def delete_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Soft-delete knowledge base
|
||||
"""
|
||||
api_logger.info(f"Request to delete knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the knowledge base exists
|
||||
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Soft-delete knowledge base
|
||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
db_knowledge.status = 2
|
||||
db.commit()
|
||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
return success(msg="The knowledge base has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
199
api/app/controllers/knowledgeshare_controller.py
Normal file
199
api/app/controllers/knowledgeshare_controller.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.models import knowledgeshare_model, knowledge_model
|
||||
from app.schemas import knowledgeshare_schema, knowledge_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledgeshare_service, knowledge_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/knowledgeshares",
|
||||
tags=["knowledgeshares"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/knowledgeshares", response_model=ApiResponse)
|
||||
async def get_knowledgeshares(
|
||||
kb_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Paged query knowledge base sharing list
|
||||
- Support filtering by kb_id
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + share list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query knowledge base sharing list: workspace_id={current_user.current_workspace_id}, kb_id={kb_id}, page={page}, pagesize={pagesize}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.source_workspace_id == current_user.current_workspace_id,
|
||||
knowledgeshare_model.KnowledgeShare.source_kb_id == kb_id
|
||||
]
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base sharing and paging query")
|
||||
total, items = knowledgeshare_service.get_knowledgeshares_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"Knowledge base sharing query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge base sharing query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of knowledge base sharing list successful")
|
||||
|
||||
|
||||
@router.post("/knowledgeshare", response_model=ApiResponse)
|
||||
async def create_knowledgeshare(
|
||||
create_data: knowledgeshare_schema.KnowledgeShareCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create knowledgeshare
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Create a knowledge base sharing request: source_kb_id={create_data.source_kb_id}, source_workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1.Create a knowledge base with permission_id=knowledge_model.PermissionType.Share
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=create_data.source_kb_id, current_user=current_user)
|
||||
knowledge = knowledge_schema.KnowledgeCreate(
|
||||
workspace_id=create_data.target_workspace_id,
|
||||
created_by=current_user.id,
|
||||
parent_id=create_data.target_workspace_id,
|
||||
name=db_knowledge.name,
|
||||
description=db_knowledge.description,
|
||||
avatar=db_knowledge.avatar,
|
||||
type=db_knowledge.type,
|
||||
permission_id=knowledge_model.PermissionType.Share,
|
||||
embedding_id=db_knowledge.embedding_id,
|
||||
reranker_id=db_knowledge.reranker_id,
|
||||
llm_id=db_knowledge.llm_id,
|
||||
image2text_id=db_knowledge.image2text_id,
|
||||
doc_num=db_knowledge.doc_num,
|
||||
chunk_num=db_knowledge.chunk_num,
|
||||
parser_id=db_knowledge.parser_id,
|
||||
parser_config=db_knowledge.parser_config
|
||||
)
|
||||
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=knowledge, current_user=current_user)
|
||||
# 2. Create a knowledge base for sharing
|
||||
api_logger.debug(f"Start creating the knowledge base sharing: {db_knowledge.name}")
|
||||
create_data.target_kb_id = db_knowledge.id
|
||||
db_knowledgeshare = knowledgeshare_service.create_knowledgeshare(db=db, knowledgeshare=create_data, current_user=current_user)
|
||||
api_logger.info(f"The knowledge base sharing has been successfully created: (ID: {db_knowledgeshare.id})")
|
||||
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="The knowledge base sharing has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the knowledge base sharing failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{knowledgeshare_id}", response_model=ApiResponse)
|
||||
async def get_knowledgeshare(
|
||||
knowledgeshare_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve knowledge base sharing information based on knowledgeshare_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base sharing information from the database
|
||||
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
|
||||
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
|
||||
if not db_knowledgeshare:
|
||||
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base sharing does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"Knowledge base sharing query successful: (ID: {db_knowledgeshare.id})")
|
||||
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="Successfully obtained knowledge base sharing information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge base sharing query failed: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/{knowledgeshare_id}", response_model=ApiResponse)
|
||||
async def delete_knowledgeshare(
|
||||
knowledgeshare_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete knowledge base sharing
|
||||
"""
|
||||
api_logger.info(f"Delete knowledge base sharing request: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base sharing information from the database
|
||||
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
|
||||
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
|
||||
if not db_knowledgeshare:
|
||||
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base sharing does not exist or access is denied"
|
||||
)
|
||||
# 2. Deleting shared knowledge base
|
||||
knowledge_service.delete_knowledge_by_id(db, knowledge_id=db_knowledgeshare.target_kb_id ,current_user=current_user)
|
||||
# 3. Delete knowledge base sharing
|
||||
api_logger.debug(f"perform knowledge base sharing delete: (ID: {knowledgeshare_id})")
|
||||
|
||||
knowledgeshare_service.delete_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
|
||||
api_logger.info(f"The knowledge base sharing has been successfully deleted: (ID: {knowledgeshare_id})")
|
||||
return success(msg="The knowledge base sharing has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
|
||||
raise
|
||||
802
api/app/controllers/memory_agent_controller.py
Normal file
802
api/app/controllers/memory_agent_controller.py
Normal file
@@ -0,0 +1,802 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
from app.db import get_db
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.models import ModelApiKey, Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services import task_service, workspace_service
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_agent_service = MemoryAgentService()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
)
|
||||
|
||||
|
||||
def validate_config_id(config_id: int, db: Session) -> int:
|
||||
"""
|
||||
Validate and ensure config_id is available, valid, and exists in database.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID to validate
|
||||
db: Database session for checking existence
|
||||
|
||||
Returns:
|
||||
int: Validated config_id
|
||||
|
||||
Raises:
|
||||
ValueError: If config_id is None, invalid, or doesn't exist in database
|
||||
"""
|
||||
if config_id is None:
|
||||
api_logger.info(f"config_id is required but was not provided")
|
||||
config_id = os.getenv('config_id')
|
||||
if config_id is None:
|
||||
raise ValueError("config_id is required but was not provided")
|
||||
|
||||
|
||||
# Check if config exists in database
|
||||
try:
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.models_model import ModelConfig
|
||||
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if config is None:
|
||||
error_msg = f"Configuration with config_id={config_id} does not exist in database"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Validate llm_id exists and is usable
|
||||
if config.llm_id:
|
||||
try:
|
||||
llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first()
|
||||
if llm_config is None:
|
||||
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
if not llm_config.is_active:
|
||||
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Error validating LLM model: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
api_logger.error(f"Config {config_id} has no llm_id set")
|
||||
raise ValueError(f"Config {config_id} has no llm_id set")
|
||||
|
||||
# Validate embedding_id exists and is usable
|
||||
if config.embedding_id:
|
||||
try:
|
||||
embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first()
|
||||
if embedding_config is None:
|
||||
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
if not embedding_config.is_active:
|
||||
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active"
|
||||
api_logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}")
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Error validating embedding model: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
api_logger.error(f"Config {config_id} has no embedding_id set")
|
||||
raise ValueError(f"Config {config_id} has no embedding_id set")
|
||||
|
||||
api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
|
||||
return config_id
|
||||
except ValueError:
|
||||
# Re-raise ValueError from above
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
|
||||
api_logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get latest health status written by Celery periodic task
|
||||
|
||||
Returns health status information from Redis cache
|
||||
"""
|
||||
api_logger.info("Health status check requested")
|
||||
try:
|
||||
result = await memory_agent_service.get_health_status()
|
||||
return success(data=result["status"])
|
||||
except Exception as e:
|
||||
api_logger.error(f"Health status check failed: {str(e)}")
|
||||
return fail(BizCode.SERVICE_UNAVAILABLE, "健康状态查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download or stream agent service log file
|
||||
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
Args:
|
||||
log_type: Log retrieval mode
|
||||
- "file": Returns complete log file content in single response (default)
|
||||
- "transmission": Real-time streaming of log content using Server-Sent Events
|
||||
|
||||
Returns:
|
||||
- file mode: ApiResponse with log content
|
||||
- transmission mode: StreamingResponse with SSE
|
||||
"""
|
||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||
|
||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||
if log_type not in ["file", "transmission"]:
|
||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
"log_type必须是'file'或'transmission'"
|
||||
)
|
||||
|
||||
# Route to appropriate mode
|
||||
if log_type == "file":
|
||||
# File mode: Return complete log file content
|
||||
try:
|
||||
log_content = memory_agent_service.get_log_content()
|
||||
return success(data=log_content)
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Log content issue: {str(e)}")
|
||||
return fail(BizCode.FILE_NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Log reading failed: {str(e)}")
|
||||
return fail(BizCode.FILE_READ_ERROR, "日志读取失败", str(e))
|
||||
|
||||
else: # log_type == "transmission"
|
||||
# Transmission mode: Stream log content using SSE
|
||||
try:
|
||||
api_logger.info("Starting SSE log streaming")
|
||||
return StreamingResponse(
|
||||
memory_agent_service.stream_log_content(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # Disable nginx buffering
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to start log streaming: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
config_id,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="写入成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write operation error: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Read service endpoint - processes read operations synchronously
|
||||
|
||||
search_switch values:
|
||||
- "0": Requires verification
|
||||
- "1": No verification, direct split
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
config_id,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read operation error: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
|
||||
|
||||
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id:str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
|
||||
Args:
|
||||
files: 上传的文件列表
|
||||
metadata: 文件元数据(可选)
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
file_content = []
|
||||
try:
|
||||
for file in files:
|
||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||
content = await file.read()
|
||||
|
||||
if file.content_type and file.content_type.startswith("image/"):
|
||||
vision_model = QWenCV(
|
||||
key=apiConfig.api_key,
|
||||
model_name=apiConfig.model_name,
|
||||
lang="Chinese",
|
||||
base_url=apiConfig.api_base
|
||||
)
|
||||
description, token_count = vision_model.describe(content)
|
||||
file_content.append(description)
|
||||
api_logger.info(f"Image processed: {file.filename}, tokens: {token_count}")
|
||||
else:
|
||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||
|
||||
result_text = ';'.join(file_content)
|
||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||
|
||||
return success(data=result_text, msg="转换文本成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server_async(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# Validate config_id
|
||||
try:
|
||||
config_id = validate_config_id(user_input.config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="查询任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async read operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_result/", response_model=ApiResponse)
|
||||
async def get_read_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async read task
|
||||
|
||||
Args:
|
||||
task_id: Celery task ID returned from /read_service_async
|
||||
|
||||
Returns:
|
||||
Task status and result if completed
|
||||
|
||||
Response format:
|
||||
- PENDING: Task is waiting to be executed
|
||||
- STARTED: Task has started
|
||||
- SUCCESS: Task completed successfully, returns result
|
||||
- FAILURE: Task failed, returns error message
|
||||
"""
|
||||
api_logger.info(f"Read task status check requested for task {task_id}")
|
||||
try:
|
||||
result = task_service.get_task_memory_read_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
if isinstance(task_result, dict):
|
||||
# 新格式:包含详细信息
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
msg="查询任务已完成"
|
||||
)
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="查询任务已完成")
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
if isinstance(error_info, dict):
|
||||
error_msg = error_info.get("error", str(error_info))
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
data={
|
||||
"status": status,
|
||||
"task_id": task_id,
|
||||
"message": "任务处理中,请稍后查询"
|
||||
},
|
||||
msg="查询任务处理中"
|
||||
)
|
||||
else:
|
||||
# 未知状态
|
||||
return success(
|
||||
data={
|
||||
"status": status,
|
||||
"task_id": task_id
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/write_result/", response_model=ApiResponse)
|
||||
async def get_write_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async write task
|
||||
|
||||
Args:
|
||||
task_id: Celery task ID returned from /writer_service_async
|
||||
|
||||
Returns:
|
||||
Task status and result if completed
|
||||
|
||||
Response format:
|
||||
- PENDING: Task is waiting to be executed
|
||||
- STARTED: Task has started
|
||||
- SUCCESS: Task completed successfully, returns result
|
||||
- FAILURE: Task failed, returns error message
|
||||
"""
|
||||
api_logger.info(f"Write task status check requested for task {task_id}")
|
||||
try:
|
||||
result = task_service.get_task_memory_write_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
if isinstance(task_result, dict):
|
||||
# 新格式:包含详细信息
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
msg="写入任务已完成"
|
||||
)
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="写入任务已完成")
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
if isinstance(error_info, dict):
|
||||
error_msg = error_info.get("error", str(error_info))
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
data={
|
||||
"status": status,
|
||||
"task_id": task_id,
|
||||
"message": "任务处理中,请稍后查询"
|
||||
},
|
||||
msg="写入任务处理中"
|
||||
)
|
||||
else:
|
||||
# 未知状态
|
||||
return success(
|
||||
data={
|
||||
"status": status,
|
||||
"task_id": task_id
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
try:
|
||||
result = await memory_agent_service.classify_message_type(user_input.message)
|
||||
return success(data=result)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Message type classification failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型判断失败", str(e))
|
||||
|
||||
|
||||
# ==================== 新增的三个接口路由 ====================
|
||||
|
||||
@router.get("/stats/types", response_model=ApiResponse)
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=only_active,
|
||||
current_workspace_id=current_user.current_workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取知识库类型统计成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"name": "用户名",
|
||||
"tags": ["产品设计师", "旅行爱好者", "摄影发烧友"],
|
||||
"hot_tags": [
|
||||
{"name": "标签1", "frequency": 10},
|
||||
{"name": "标签2", "frequency": 8},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"User profile requested: end_user_id={end_user_id}, current_user={current_user.id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_user_profile(
|
||||
end_user_id=end_user_id,
|
||||
current_user_id=str(current_user.id)
|
||||
)
|
||||
return success(data=result, msg="获取用户详情成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"User profile failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取用户详情失败", str(e))
|
||||
|
||||
|
||||
# @router.get("/docs/api", response_model=ApiResponse)
|
||||
# async def get_api_docs_api(
|
||||
# file_path: Optional[str] = Query(None, description="API文档文件路径,不传则使用默认路径")
|
||||
# ):
|
||||
# """
|
||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
# Returns:
|
||||
# Parsed API documentation including title, meta info, and sections
|
||||
# """
|
||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||
# try:
|
||||
# result = await memory_agent_service.get_api_docs(file_path)
|
||||
|
||||
# if result.get("success"):
|
||||
# return success(msg=result["msg"], data=result["data"])
|
||||
# else:
|
||||
# return fail(
|
||||
# code=BizCode.BAD_REQUEST,
|
||||
# msg=result["msg"],
|
||||
# error=result.get("data", {}).get("error", result.get("error_code", ""))
|
||||
# )
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"API docs retrieval failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))
|
||||
516
api/app/controllers/memory_dashboard_controller.py
Normal file
516
api/app/controllers/memory_dashboard_controller.py
Normal file
@@ -0,0 +1,516 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/dashboard",
|
||||
tags=["Dashboard"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/total_end_users", response_model=ApiResponse)
|
||||
def get_workspace_total_end_users(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取用户列表的总用户数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
total_end_users = memory_dashboard_service.get_workspace_total_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||
return success(data=total_end_users, msg="用户数量获取成功")
|
||||
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
result.append(
|
||||
{
|
||||
'end_user':end_user,
|
||||
'memory_num':memory_num
|
||||
}
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@router.get("/memory_increment", response_model=ApiResponse)
|
||||
def get_workspace_memory_increment(
|
||||
limit: int = Query(7, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间的记忆增量"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆增量")
|
||||
memory_increment = memory_dashboard_service.get_workspace_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
limit=limit
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
|
||||
return success(data=memory_increment, msg="记忆增量获取成功")
|
||||
|
||||
|
||||
@router.get("/api_increment", response_model=ApiResponse)
|
||||
def get_workspace_api_increment(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取API调用趋势"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的API调用增量")
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取 {api_increment} API调用增量")
|
||||
return success(data=api_increment, msg="API调用增量获取成功")
|
||||
|
||||
|
||||
@router.post("/total_memory", response_model=ApiResponse)
|
||||
def write_workspace_total_memory(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""工作空间记忆总量的写入(异步任务)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求写入工作空间 {workspace_id} 的记忆总量")
|
||||
|
||||
# 触发 Celery 异步任务
|
||||
from app.celery_app import celery_app
|
||||
task = celery_app.send_task(
|
||||
"app.controllers.memory_storage_controller.search_all",
|
||||
kwargs={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
api_logger.info(f"已触发记忆总量统计任务,task_id: {task.id}")
|
||||
return success(
|
||||
data={"task_id": task.id, "workspace_id": str(workspace_id)},
|
||||
msg="记忆总量统计任务已启动"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/task_status/{task_id}", response_model=ApiResponse)
|
||||
def get_task_status(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""查询异步任务的执行状态和结果"""
|
||||
api_logger.info(f"用户 {current_user.username} 查询任务状态: task_id={task_id}")
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from celery.result import AsyncResult
|
||||
|
||||
# 获取任务结果
|
||||
task_result = AsyncResult(task_id, app=celery_app)
|
||||
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"status": task_result.state, # PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED
|
||||
}
|
||||
|
||||
# 如果任务完成,返回结果
|
||||
if task_result.ready():
|
||||
if task_result.successful():
|
||||
response_data["result"] = task_result.result
|
||||
api_logger.info(f"任务 {task_id} 执行成功")
|
||||
return success(data=response_data, msg="任务执行成功")
|
||||
else:
|
||||
# 任务失败
|
||||
response_data["error"] = str(task_result.result)
|
||||
api_logger.error(f"任务 {task_id} 执行失败: {task_result.result}")
|
||||
return success(data=response_data, msg="任务执行失败")
|
||||
else:
|
||||
# 任务还在执行中
|
||||
api_logger.info(f"任务 {task_id} 状态: {task_result.state}")
|
||||
return success(data=response_data, msg=f"任务状态: {task_result.state}")
|
||||
|
||||
|
||||
@router.get("/memory_list", response_model=ApiResponse)
|
||||
def get_workspace_memory_list(
|
||||
limit: int = Query(7, description="记忆增量返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
用户记忆列表整合接口
|
||||
|
||||
整合以下三个接口的数据:
|
||||
1. total_memory - 工作空间记忆总量
|
||||
2. memory_increment - 工作空间记忆增量
|
||||
3. hosts - 工作空间宿主列表
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"total_memory": float,
|
||||
"memory_increment": [
|
||||
{"date": "2024-01-01", "count": 100},
|
||||
...
|
||||
],
|
||||
"hosts": [
|
||||
{"id": "uuid", "name": "宿主名", ...},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆列表")
|
||||
memory_list = memory_dashboard_service.get_workspace_memory_list(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
limit=limit
|
||||
)
|
||||
api_logger.info(f"成功获取记忆列表")
|
||||
return success(data=memory_list, msg="记忆列表获取成功")
|
||||
|
||||
|
||||
@router.get("/total_memory_count", response_model=ApiResponse)
|
||||
async def get_workspace_total_memory_count(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的记忆总量(通过聚合所有host的记忆数)
|
||||
|
||||
逻辑:
|
||||
1. 从 memory_list 获取所有 host_id
|
||||
2. 对每个 host_id 调用 search_all 获取 total
|
||||
3. 将所有 total 求和返回
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"total_memory_count": int,
|
||||
"host_count": int,
|
||||
"details": [
|
||||
{"host_id": "uuid", "count": 100},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆总量")
|
||||
total_memory_count = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
api_logger.info(f"成功获取记忆总量: {total_memory_count.get('total_memory_count', 0)}")
|
||||
return success(data=total_memory_count, msg="记忆总量获取成功")
|
||||
|
||||
|
||||
# ======== RAG 数据统计 ========
|
||||
@router.get("/total_rag_count", response_model=ApiResponse)
|
||||
def get_workspace_total_rag_count(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取 rag 的总文档数、总chunk数、总知识库数量、总api调用数量
|
||||
"""
|
||||
total_documents = memory_dashboard_service.get_rag_total_doc(db, current_user)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
data = {
|
||||
'total_documents':total_documents,
|
||||
'total_chunk':total_chunk,
|
||||
'total_kb':total_kb,
|
||||
'total_api':1024
|
||||
}
|
||||
return success(data=data, msg="RAG相关数据获取成功")
|
||||
|
||||
@router.get("/current_user_rag_total_num", response_model=ApiResponse)
|
||||
def get_current_user_rag_total_num(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主的 RAG 的总chunk数量
|
||||
"""
|
||||
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
|
||||
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
|
||||
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_summary_tag", response_model=ApiResponse)
|
||||
async def get_chunk_summary_tag(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
max_tags: int = Query(10, description="最大标签数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
max_tags=max_tags,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
async def get_chunk_insight(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
整合dashboard数据接口
|
||||
|
||||
整合以下接口的数据:
|
||||
1. /dashboard/total_memory_count - 记忆总量
|
||||
2. /dashboard/api_increment - API调用增量
|
||||
3. /memory/stats/types - 知识库类型统计(只要total数据)
|
||||
4. /dashboard/total_rag_count - RAG相关数据
|
||||
|
||||
根据 storage_type 判断调用不同的接口
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"storage_type": str,
|
||||
"neo4j_data": {
|
||||
"total_memory": int,
|
||||
"total_app": int,
|
||||
"total_knowledge": int,
|
||||
"total_api_call": int
|
||||
} | null,
|
||||
"rag_data": {
|
||||
"total_memory": int,
|
||||
"total_app": int,
|
||||
"total_knowledge": int,
|
||||
"total_api_call": int
|
||||
} | null
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
|
||||
user_rag_memory_id = None
|
||||
|
||||
# 根据 storage_type 决定返回哪个数据对象
|
||||
# 如果是 'rag',neo4j_data 为 null;否则 rag_data 为 null
|
||||
result = {
|
||||
"storage_type": storage_type,
|
||||
"neo4j_data": None,
|
||||
"rag_data": None
|
||||
}
|
||||
|
||||
try:
|
||||
# 如果 storage_type 为 'neo4j' 或空,获取 neo4j_data
|
||||
if storage_type == 'neo4j':
|
||||
neo4j_data = {
|
||||
"total_memory": None,
|
||||
"total_app": None,
|
||||
"total_knowledge": None,
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 1. 获取记忆总量(total_memory)
|
||||
try:
|
||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
neo4j_data["total_app"] = len(apps_orm)
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取知识库类型统计(total_knowledge)
|
||||
try:
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
memory_agent_service = MemoryAgentService()
|
||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=True,
|
||||
current_workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info(f"成功获取neo4j_data")
|
||||
|
||||
# 如果 storage_type 为 'rag',获取 rag_data
|
||||
elif storage_type == 'rag':
|
||||
rag_data = {
|
||||
"total_memory": None,
|
||||
"total_app": None,
|
||||
"total_knowledge": None,
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
rag_data["total_app"] = len(apps_orm)
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info(f"成功获取rag_data")
|
||||
|
||||
api_logger.info(f"成功获取dashboard整合数据")
|
||||
return success(data=result, msg="Dashboard数据获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取dashboard整合数据失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取dashboard整合数据失败: {str(e)}"
|
||||
)
|
||||
542
api/app/controllers/memory_storage_controller.py
Normal file
542
api/app/controllers/memory_storage_controller.py
Normal file
@@ -0,0 +1,542 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.memory_storage_service import (
|
||||
MemoryStorageService,
|
||||
DataConfigService,
|
||||
kb_type_distribution,
|
||||
search_dialogue,
|
||||
search_chunk,
|
||||
search_statement,
|
||||
search_entity,
|
||||
search_all,
|
||||
search_detials,
|
||||
search_edges,
|
||||
search_entity_graph,
|
||||
analytics_hot_memory_tags,
|
||||
analytics_memory_insight_report,
|
||||
analytics_recent_activity_stats,
|
||||
analytics_user_summary,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
ConfigPilotRun,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
memory_storage_service = MemoryStorageService()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory-storage",
|
||||
tags=["Memory Storage"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
|
||||
Args:
|
||||
storage_id: Storage identifier
|
||||
|
||||
Returns:
|
||||
Storage information
|
||||
"""
|
||||
api_logger.info(f"Storage info requested ")
|
||||
try:
|
||||
result = await memory_storage_service.get_storage_info()
|
||||
return success(data=result)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Storage info retrieval failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
payload.workspace_id = workspace_id
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.update(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.update_extracted(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.get_extracted(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
# 传递 workspace_id 进行过滤(保持为 UUID 类型)
|
||||
result = svc.get_all(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read all config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
# 先尝试从数据库加载配置
|
||||
try:
|
||||
config_loaded = reload_configuration_from_database(str(payload.config_id))
|
||||
if not config_loaded:
|
||||
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
|
||||
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Exception while loading configuration: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
||||
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
result = await svc.pilot_run(payload)
|
||||
return success(data=result, msg="试运行完成")
|
||||
except ValueError as e:
|
||||
# 捕获参数验证错误
|
||||
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Pilot run failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"KB type distribution failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search dialogue failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "对话查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search chunk failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "分块查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search statement failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "语句查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search entity failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "实体查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search all failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "全部查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search details failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "详情查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search edges failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
@router.get("/search/entity_graph", response_model=ApiResponse)
|
||||
async def search_for_entity_graph(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
搜索所有实体之间的关系网络
|
||||
"""
|
||||
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity_graph(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search entity graph failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await analytics_hot_memory_tags(end_user_id, limit)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await analytics_memory_insight_report(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Memory insight report failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await analytics_user_summary(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"User summary failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
|
||||
|
||||
Args:
|
||||
None
|
||||
Returns:
|
||||
自我反思结果。
|
||||
"""
|
||||
return await self_reflexion(host_id)
|
||||
332
api/app/controllers/model_controller.py
Normal file
332
api/app/controllers/model_controller.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from fastapi import APIRouter, Depends, status, Query
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/models",
|
||||
tags=["Models"],
|
||||
)
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")
|
||||
|
||||
try:
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据ID获取模型配置
|
||||
"""
|
||||
api_logger.info(f"获取模型配置请求: model_id={model_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始获取模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
api_logger.info(f"模型配置获取成功: {result_orm.name}")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
|
||||
return success(data=result_pydantic, msg="模型配置获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
async def create_model(
|
||||
model_data: model_schema.ModelConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建模型配置
|
||||
|
||||
- 创建模型配置基础信息
|
||||
- 如果包含 API Key,会先验证配置有效性,然后创建
|
||||
- 验证失败时会抛出异常,不会创建配置
|
||||
- 可通过 skip_validation=true 跳过验证
|
||||
"""
|
||||
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始创建模型配置: {model_data.name}")
|
||||
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data)
|
||||
api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
|
||||
return success(data=result, msg="模型配置创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型配置失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.ModelConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
|
||||
api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
|
||||
return success(data=result_pydantic, msg="模型配置更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/{model_id}", response_model=ApiResponse)
|
||||
def delete_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除模型配置
|
||||
"""
|
||||
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始删除模型配置: model_id={model_id}")
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id)
|
||||
api_logger.info(f"模型配置删除成功: model_id={model_id}")
|
||||
return success(msg="模型配置删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# API Key 相关接口
|
||||
@router.get("/{model_id}/apikeys", response_model=ApiResponse)
|
||||
def get_model_api_keys(
|
||||
model_id: uuid.UUID,
|
||||
is_active: bool = Query(True, description="是否只获取活跃的API Key"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型的API Key列表
|
||||
"""
|
||||
api_logger.info(f"获取模型API Key列表请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始获取模型API Key列表: model_id={model_id}")
|
||||
result_orm = ModelApiKeyService.get_api_keys_by_model(
|
||||
db=db, model_config_id=model_id, is_active=is_active
|
||||
)
|
||||
|
||||
# 将ORM对象列表转换为Pydantic模型列表
|
||||
result_pydantic = [model_schema.ModelApiKey.model_validate(item) for item in result_orm]
|
||||
|
||||
api_logger.info(f"模型API Key列表获取成功: 数量={len(result_pydantic)}")
|
||||
return success(data=result_pydantic, msg="模型API Key列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型API Key列表失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
api_key_data: model_schema.ModelApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
为模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建模型API Key请求: model_id={model_id}, model_name={api_key_data.model_name}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/apikeys/{api_key_id}", response_model=ApiResponse)
|
||||
def get_api_key_by_id(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据ID获取API Key
|
||||
"""
|
||||
api_logger.info(f"获取API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始获取API Key: api_key_id={api_key_id}")
|
||||
result = ModelApiKeyService.get_api_key_by_id(db=db, api_key_id=api_key_id)
|
||||
api_logger.info(f"API Key获取成功: {result.model_name}")
|
||||
return success(data=result, msg="API Key获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取API Key失败: api_key_id={api_key_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/apikeys/{api_key_id}", response_model=ApiResponse)
|
||||
async def update_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
api_key_data: model_schema.ModelApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新API Key
|
||||
"""
|
||||
api_logger.info(f"更新API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新API Key: api_key_id={api_key_id}")
|
||||
result = await ModelApiKeyService.update_api_key(db=db, api_key_id=api_key_id, api_key_data=api_key_data)
|
||||
api_logger.info(f"API Key更新成功: {result.model_name} (ID: {api_key_id})")
|
||||
result_pydantic = model_schema.ModelApiKey.model_validate(result)
|
||||
return success(data=result_pydantic, msg="API Key更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/apikeys/{api_key_id}", response_model=ApiResponse)
|
||||
def delete_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除API Key
|
||||
"""
|
||||
api_logger.info(f"删除API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始删除API Key: api_key_id={api_key_id}")
|
||||
ModelApiKeyService.delete_api_key(db=db, api_key_id=api_key_id)
|
||||
api_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
|
||||
return success(msg="API Key删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/validate", response_model=ApiResponse)
|
||||
async def validate_model_config(
|
||||
validate_data: model_schema.ModelValidateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
验证模型配置是否有效
|
||||
|
||||
支持验证不同类型的模型:
|
||||
- llm: 大语言模型
|
||||
- chat: 对话模型
|
||||
- embedding: 向量模型
|
||||
- rerank: 重排序模型
|
||||
"""
|
||||
api_logger.info(f"验证模型配置请求: {validate_data.model_name} ({validate_data.model_type}), 用户: {current_user.username}")
|
||||
|
||||
result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=validate_data.model_name,
|
||||
provider=validate_data.provider,
|
||||
api_key=validate_data.api_key,
|
||||
api_base=validate_data.api_base,
|
||||
model_type=validate_data.model_type,
|
||||
test_message=validate_data.test_message
|
||||
)
|
||||
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
404
api/app/controllers/multi_agent_controller.py
Normal file
404
api/app/controllers/multi_agent_controller.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""多 Agent 控制器"""
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Query, Path
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas import multi_agent_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import User
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Multi-Agent"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
# ==================== 多 Agent 配置管理 ====================
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent",
|
||||
summary="创建多 Agent 配置"
|
||||
)
|
||||
def create_multi_agent_config(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
data: multi_agent_schema.MultiAgentConfigCreate = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建多 Agent 配置
|
||||
|
||||
支持四种编排模式:
|
||||
- sequential: 顺序执行
|
||||
- parallel: 并行执行
|
||||
- conditional: 条件路由
|
||||
- loop: 循环执行
|
||||
"""
|
||||
service = MultiAgentService(db)
|
||||
config = service.create_config(
|
||||
app_id=app_id,
|
||||
data=data,
|
||||
created_by=current_user.id
|
||||
)
|
||||
|
||||
return success(
|
||||
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
|
||||
msg="多 Agent 配置创建成功"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{app_id}/multi-agent",
|
||||
summary="获取当前应用的最新有效多 Agent 配置"
|
||||
)
|
||||
def get_multi_agent_configs(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取指定应用的最新有效多 Agent 配置,如果不存在则返回默认模板"""
|
||||
service = MultiAgentService(db)
|
||||
|
||||
# 通过 app_id 获取最新有效配置(已转换 agent_id 为 app_id)
|
||||
config = service.get_multi_agent_configs(app_id)
|
||||
|
||||
if not config:
|
||||
# 返回默认模板
|
||||
default_template = {
|
||||
"app_id": str(app_id),
|
||||
"master_agent_id": None,
|
||||
"master_agent_name": None,
|
||||
"orchestration_mode": "conditional",
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": {
|
||||
"max_iterations": 10,
|
||||
"timeout": 300,
|
||||
"enable_parallel": False,
|
||||
"error_handling": "stop"
|
||||
},
|
||||
"aggregation_strategy": "merge",
|
||||
}
|
||||
return success(
|
||||
data=default_template,
|
||||
msg="该应用暂无配置,返回默认模板"
|
||||
)
|
||||
|
||||
# config 已经是字典格式,直接返回
|
||||
return success(data=config)
|
||||
|
||||
@router.put(
|
||||
"/{app_id}/multi-agent",
|
||||
summary="更新多 Agent 配置"
|
||||
)
|
||||
def update_multi_agent_config(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
data: multi_agent_schema.MultiAgentConfigUpdate = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新多 Agent 配置"""
|
||||
service = MultiAgentService(db)
|
||||
config = service.update_config(app_id, data)
|
||||
|
||||
return success(
|
||||
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
|
||||
msg="多 Agent 配置更新成功"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{app_id}/multi-agent",
|
||||
summary="删除多 Agent 配置"
|
||||
)
|
||||
def delete_multi_agent_config(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除多 Agent 配置"""
|
||||
service = MultiAgentService(db)
|
||||
service.delete_config(app_id)
|
||||
|
||||
return success(msg="多 Agent 配置删除成功")
|
||||
|
||||
# ==================== 多 Agent 运行 ====================
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/run",
|
||||
summary="运行多 Agent 任务"
|
||||
)
|
||||
async def run_multi_agent(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: multi_agent_schema.MultiAgentRunRequest = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""运行多 Agent 任务
|
||||
|
||||
根据配置的编排模式执行多个 Agent:
|
||||
- sequential: 按优先级顺序执行
|
||||
- parallel: 并行执行所有 Agent
|
||||
- conditional: 根据条件选择 Agent
|
||||
- loop: 循环执行直到满足条件
|
||||
"""
|
||||
service = MultiAgentService(db)
|
||||
result = await service.run(app_id, request)
|
||||
|
||||
return success(
|
||||
data=multi_agent_schema.MultiAgentRunResponse(**result),
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 智能路由测试 ====================
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/test-routing",
|
||||
summary="测试智能路由"
|
||||
)
|
||||
async def test_routing(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: multi_agent_schema.RoutingTestRequest = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""测试智能路由功能
|
||||
|
||||
支持三种路由模式:
|
||||
- keyword: 仅使用关键词路由
|
||||
- llm: 使用 LLM 路由(需要提供 routing_model_id)
|
||||
- hybrid: 混合路由(关键词 + LLM)
|
||||
|
||||
参数:
|
||||
- message: 测试消息
|
||||
- conversation_id: 会话 ID(可选)
|
||||
- routing_model_id: 路由模型 ID(可选,用于 LLM 路由)
|
||||
- use_llm: 是否启用 LLM(默认 False)
|
||||
- keyword_threshold: 关键词置信度阈值(默认 0.8)
|
||||
"""
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.services.llm_router import LLMRouter
|
||||
from app.models import ModelConfig
|
||||
|
||||
# 1. 获取多 Agent 配置
|
||||
service = MultiAgentService(db)
|
||||
config = service.get_config(app_id)
|
||||
|
||||
if not config:
|
||||
return success(
|
||||
data=None,
|
||||
msg="应用未配置多 Agent,无法测试路由"
|
||||
)
|
||||
|
||||
# 2. 准备子 Agent 信息
|
||||
sub_agents = {}
|
||||
for sub_agent_info in config.sub_agents:
|
||||
agent_id = sub_agent_info["agent_id"]
|
||||
sub_agents[agent_id] = {
|
||||
"name": sub_agent_info.get("name", agent_id),
|
||||
"role": sub_agent_info.get("role", "")
|
||||
}
|
||||
|
||||
# 3. 获取路由模型(如果指定)
|
||||
routing_model = None
|
||||
if request.routing_model_id:
|
||||
routing_model = db.get(ModelConfig, request.routing_model_id)
|
||||
if not routing_model:
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"路由模型不存在: {request.routing_model_id}"
|
||||
)
|
||||
|
||||
# 4. 初始化路由器
|
||||
state_manager = ConversationStateManager()
|
||||
router = LLMRouter(
|
||||
db=db,
|
||||
state_manager=state_manager,
|
||||
routing_rules=config.routing_rules or [],
|
||||
sub_agents=sub_agents,
|
||||
routing_model_config=routing_model,
|
||||
use_llm=request.use_llm and routing_model is not None
|
||||
)
|
||||
|
||||
# 5. 设置阈值
|
||||
if request.keyword_threshold:
|
||||
router.keyword_high_confidence_threshold = request.keyword_threshold
|
||||
|
||||
# 6. 执行路由
|
||||
try:
|
||||
routing_result = await router.route(
|
||||
message=request.message,
|
||||
conversation_id=str(request.conversation_id) if request.conversation_id else None,
|
||||
force_new=request.force_new
|
||||
)
|
||||
|
||||
# 7. 获取 Agent 信息
|
||||
agent_id = routing_result["agent_id"]
|
||||
agent_info = sub_agents.get(agent_id, {})
|
||||
|
||||
# 8. 构建响应
|
||||
response_data = {
|
||||
"message": request.message,
|
||||
"routing_result": {
|
||||
"agent_id": agent_id,
|
||||
"agent_name": agent_info.get("name", agent_id),
|
||||
"agent_role": agent_info.get("role", ""),
|
||||
"confidence": routing_result["confidence"],
|
||||
"strategy": routing_result["strategy"],
|
||||
"topic": routing_result["topic"],
|
||||
"topic_changed": routing_result["topic_changed"],
|
||||
"reason": routing_result["reason"],
|
||||
"routing_method": routing_result["routing_method"]
|
||||
},
|
||||
"cmulti-agent/batch-test-routingonfig_info": {
|
||||
"use_llm": request.use_llm and routing_model is not None,
|
||||
"routing_model": routing_model.name if routing_model else None,
|
||||
"keyword_threshold": router.keyword_high_confidence_threshold,
|
||||
"total_sub_agents": len(sub_agents)
|
||||
}
|
||||
}
|
||||
|
||||
return success(
|
||||
data=response_data,
|
||||
msg="路由测试成功"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"路由测试失败: {str(e)}")
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"路由测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/",
|
||||
summary="批量测试智能路由"
|
||||
)
|
||||
async def batch_test_routing(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: multi_agent_schema.BatchRoutingTestRequest = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""批量测试智能路由功能
|
||||
|
||||
用于测试多条消息的路由效果,并统计准确率
|
||||
|
||||
参数:
|
||||
- test_cases: 测试用例列表
|
||||
- routing_model_id: 路由模型 ID(可选)
|
||||
- use_llm: 是否启用 LLM
|
||||
- keyword_threshold: 关键词置信度阈值
|
||||
"""
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.services.llm_router import LLMRouter
|
||||
from app.models import ModelConfig
|
||||
|
||||
# 1. 获取多 Agent 配置
|
||||
service = MultiAgentService(db)
|
||||
config = service.get_config(app_id)
|
||||
|
||||
if not config:
|
||||
return success(
|
||||
data=None,
|
||||
msg="应用未配置多 Agent,无法测试路由"
|
||||
)
|
||||
|
||||
# 2. 准备子 Agent 信息
|
||||
sub_agents = {}
|
||||
for sub_agent_info in config.sub_agents:
|
||||
agent_id = sub_agent_info["agent_id"]
|
||||
sub_agents[agent_id] = {
|
||||
"name": sub_agent_info.get("name", agent_id),
|
||||
"role": sub_agent_info.get("role", "")
|
||||
}
|
||||
|
||||
# 3. 获取路由模型
|
||||
routing_model = None
|
||||
if request.routing_model_id:
|
||||
routing_model = db.get(ModelConfig, request.routing_model_id)
|
||||
|
||||
# 4. 初始化路由器
|
||||
state_manager = ConversationStateManager()
|
||||
router = LLMRouter(
|
||||
db=db,
|
||||
state_manager=state_manager,
|
||||
routing_rules=config.routing_rules or [],
|
||||
sub_agents=sub_agents,
|
||||
routing_model_config=routing_model,
|
||||
use_llm=request.use_llm and routing_model is not None
|
||||
)
|
||||
|
||||
if request.keyword_threshold:
|
||||
router.keyword_high_confidence_threshold = request.keyword_threshold
|
||||
|
||||
# 5. 批量测试
|
||||
results = []
|
||||
correct_count = 0
|
||||
total_count = len(request.test_cases)
|
||||
|
||||
for test_case in request.test_cases:
|
||||
try:
|
||||
routing_result = await router.route(
|
||||
message=test_case.message,
|
||||
conversation_id=str(uuid.uuid4()) # 每个测试用例使用独立会话
|
||||
)
|
||||
|
||||
agent_id = routing_result["agent_id"]
|
||||
agent_info = sub_agents.get(agent_id, {})
|
||||
|
||||
# 判断是否正确
|
||||
is_correct = None
|
||||
if test_case.expected_agent_id:
|
||||
is_correct = (agent_id == str(test_case.expected_agent_id))
|
||||
if is_correct:
|
||||
correct_count += 1
|
||||
|
||||
results.append({
|
||||
"message": test_case.message,
|
||||
"description": test_case.description,
|
||||
"routed_agent_id": agent_id,
|
||||
"routed_agent_name": agent_info.get("name"),
|
||||
"expected_agent_id": str(test_case.expected_agent_id) if test_case.expected_agent_id else None,
|
||||
"is_correct": is_correct,
|
||||
"confidence": routing_result["confidence"],
|
||||
"routing_method": routing_result["routing_method"],
|
||||
"strategy": routing_result["strategy"]
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试用例失败: {test_case.message}, 错误: {str(e)}")
|
||||
results.append({
|
||||
"message": test_case.message,
|
||||
"description": test_case.description,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 6. 统计
|
||||
accuracy = None
|
||||
if correct_count > 0:
|
||||
total_with_expected = sum(1 for r in results if r.get("expected_agent_id"))
|
||||
if total_with_expected > 0:
|
||||
accuracy = correct_count / total_with_expected * 100
|
||||
|
||||
response_data = {
|
||||
"total_count": total_count,
|
||||
"correct_count": correct_count,
|
||||
"accuracy": accuracy,
|
||||
"results": results,
|
||||
"config_info": {
|
||||
"use_llm": request.use_llm and routing_model is not None,
|
||||
"routing_model": routing_model.name if routing_model else None,
|
||||
"keyword_threshold": router.keyword_high_confidence_threshold
|
||||
}
|
||||
}
|
||||
|
||||
return success(
|
||||
data=response_data,
|
||||
msg=f"批量测试完成,准确率: {accuracy:.1f}%" if accuracy else "批量测试完成"
|
||||
)
|
||||
437
api/app/controllers/public_share_controller.py
Normal file
437
api/app/controllers/public_share_controller.py
Normal file
@@ -0,0 +1,437 @@
|
||||
from fastapi import APIRouter, Depends, Query, Request, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import jwt
|
||||
from typing import Optional, Dict
|
||||
from functools import wraps
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.config import settings
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def get_base_url(request: Request) -> str:
|
||||
"""从请求中获取基础 URL"""
|
||||
return f"{request.url.scheme}://{request.url.netloc}"
|
||||
|
||||
|
||||
def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
|
||||
"""获取或生成用户 ID
|
||||
|
||||
优先级:
|
||||
1. 使用前端传递的 user_id
|
||||
2. 基于 IP + User-Agent 生成唯一 ID
|
||||
|
||||
Args:
|
||||
payload_user_id: 前端传递的 user_id
|
||||
request: FastAPI Request 对象
|
||||
|
||||
Returns:
|
||||
用户 ID
|
||||
"""
|
||||
if payload_user_id:
|
||||
return payload_user_id
|
||||
|
||||
# 获取客户端 IP
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# 获取 User-Agent
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
# 生成唯一 ID:基于 IP + User-Agent 的哈希
|
||||
unique_string = f"{client_ip}_{user_agent}"
|
||||
hash_value = hashlib.md5(unique_string.encode()).hexdigest()[:16]
|
||||
|
||||
return f"guest_{hash_value}"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{share_token}/token",
|
||||
summary="获取访问 token"
|
||||
)
|
||||
def get_access_token(
|
||||
share_token: str,
|
||||
payload: release_share_schema.TokenRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取访问 token
|
||||
|
||||
- 用户通过 user_id + share_token 换取访问 token
|
||||
- 后续请求需要携带此 token
|
||||
"""
|
||||
# 获取或生成 user_id
|
||||
user_id = get_or_generate_user_id(payload.user_id, request)
|
||||
|
||||
# 验证分享链接(可选:验证密码)
|
||||
service = ReleaseShareService(db)
|
||||
try:
|
||||
service.get_shared_release_info(
|
||||
share_token=share_token,
|
||||
password=payload.password
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取分享信息失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 生成 token
|
||||
access_token = create_access_token(user_id, share_token)
|
||||
|
||||
logger.info(
|
||||
f"生成访问 token",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"user_id": user_id
|
||||
}
|
||||
)
|
||||
|
||||
return success(data={
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"user_id": user_id
|
||||
})
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="获取公开分享的应用信息",
|
||||
response_model=None
|
||||
)
|
||||
def get_shared_release(
|
||||
password: str = Query(None, description="访问密码(如果需要)"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取公开分享的发布版本信息
|
||||
|
||||
- 无需认证即可访问
|
||||
- 如果设置了密码保护,需要提供正确的密码
|
||||
- 如果密码错误或未提供密码,返回基本信息(不含配置详情)
|
||||
"""
|
||||
service = ReleaseShareService(db)
|
||||
info = service.get_shared_release_info(
|
||||
share_token=share_data.share_token,
|
||||
password=password
|
||||
)
|
||||
|
||||
return success(data=info)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/verify",
|
||||
summary="验证访问密码"
|
||||
)
|
||||
def verify_password(
|
||||
payload: release_share_schema.PasswordVerifyRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""验证分享的访问密码
|
||||
|
||||
- 用于前端先验证密码,再获取完整信息
|
||||
"""
|
||||
service = ReleaseShareService(db)
|
||||
is_valid = service.verify_password(
|
||||
share_token=share_data.share_token,
|
||||
password=payload.password
|
||||
)
|
||||
|
||||
return success(data={"valid": is_valid})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/embed",
|
||||
summary="获取嵌入代码"
|
||||
)
|
||||
def get_embed_code(
|
||||
width: str = Query("100%", description="iframe 宽度"),
|
||||
height: str = Query("600px", description="iframe 高度"),
|
||||
request: Request = None,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取嵌入代码
|
||||
|
||||
- 返回 iframe 嵌入代码
|
||||
- 可以自定义宽度和高度
|
||||
"""
|
||||
base_url = get_base_url(request) if request else None
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
embed_code = service.get_embed_code(
|
||||
share_token=share_data.share_token,
|
||||
width=width,
|
||||
height=height,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
return success(data=embed_code)
|
||||
|
||||
|
||||
|
||||
# ---------- 会话管理接口 ----------
|
||||
|
||||
@router.get(
|
||||
"/conversations",
|
||||
summary="获取会话列表"
|
||||
)
|
||||
def list_conversations(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取分享应用的会话列表
|
||||
|
||||
- 可以按 user_id 筛选
|
||||
- 支持分页
|
||||
"""
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
password=password,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
items = [conversation_schema.Conversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/conversations/{conversation_id}",
|
||||
summary="获取会话详情(含消息)"
|
||||
)
|
||||
def get_conversation(
|
||||
conversation_id: uuid.UUID,
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取会话详情和消息历史"""
|
||||
chat_service = SharedChatService(db)
|
||||
conversation = chat_service.get_conversation_messages(
|
||||
share_token=share_data.share_token,
|
||||
conversation_id=conversation_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
# 获取消息
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
|
||||
return success(data=conv_dict)
|
||||
|
||||
|
||||
# ---------- 聊天接口 ----------
|
||||
|
||||
@router.post(
|
||||
"/chat",
|
||||
summary="发送消息(支持流式和非流式)"
|
||||
)
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
使用 Bearer token 认证:
|
||||
- Header: Authorization: Bearer {token}
|
||||
- user_id 和 share_token 从 token 中解码
|
||||
|
||||
- 支持多轮对话(提供 conversation_id)
|
||||
- 支持流式返回(设置 stream=true)
|
||||
- 如果不提供 conversation_id,会自动创建新会话
|
||||
"""
|
||||
service = SharedChatService(db)
|
||||
|
||||
# 从依赖中获取 user_id 和 share_token
|
||||
user_id = share_data.user_id
|
||||
share_token = share_data.share_token
|
||||
password = None # Token 认证不需要密码
|
||||
# end_user_id = user_id
|
||||
other_id = user_id
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
|
||||
# 获取应用类型
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = service.create_or_get_conversation(
|
||||
share_token=share_data.share_token,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
password=password
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"参数验证完成",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"app_type": app_type,
|
||||
"conversation_id": str(conversation.id),
|
||||
"stream": payload.stream
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 验证失败,直接抛出异常(会被 FastAPI 的异常处理器捕获)
|
||||
logger.error(f"参数验证失败: {str(e)}")
|
||||
raise
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in service.chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
result = await service.chat(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in service.multi_agent_chat_stream(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await service.multi_agent_chat(
|
||||
share_token=share_token,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
170
api/app/controllers/release_share_controller.py
Normal file
170
api/app/controllers/release_share_controller.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas import release_share_schema
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
router = APIRouter(tags=["Release Share"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def get_base_url(request: Request) -> str:
|
||||
"""从请求中获取基础 URL"""
|
||||
return f"{request.url.scheme}://{request.url.netloc}"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/apps/{app_id}/releases/{release_id}/share",
|
||||
summary="创建/启用分享配置"
|
||||
)
|
||||
@cur_workspace_access_guard()
|
||||
def create_share(
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
payload: release_share_schema.ReleaseShareCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""创建或更新发布版本的分享配置
|
||||
|
||||
- 如果已存在分享配置,则更新
|
||||
- 自动生成唯一的分享 token
|
||||
- 返回完整的分享 URL
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
base_url = get_base_url(request)
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
share = service.create_or_update_share(
|
||||
release_id=release_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
data=payload,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
share_schema = service._convert_to_schema(share, base_url)
|
||||
return success(data=share_schema, msg="分享配置已创建")
|
||||
|
||||
|
||||
@router.put(
|
||||
"/apps/{app_id}/releases/{release_id}/share",
|
||||
summary="更新分享配置"
|
||||
)
|
||||
@cur_workspace_access_guard()
|
||||
def update_share(
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
payload: release_share_schema.ReleaseShareUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""更新分享配置
|
||||
|
||||
- 可以更新启用状态、密码、嵌入设置等
|
||||
- 不会改变 share_token
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
base_url = get_base_url(request)
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
share = service.update_share(
|
||||
release_id=release_id,
|
||||
workspace_id=workspace_id,
|
||||
data=payload
|
||||
)
|
||||
|
||||
share_schema = service._convert_to_schema(share, base_url)
|
||||
return success(data=share_schema, msg="分享配置已更新")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/apps/{app_id}/releases/{release_id}/share",
|
||||
summary="获取分享配置"
|
||||
)
|
||||
@cur_workspace_access_guard()
|
||||
def get_share(
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取发布版本的分享配置
|
||||
|
||||
- 如果不存在分享配置,返回 null
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
base_url = get_base_url(request)
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
share = service.get_share(
|
||||
release_id=release_id,
|
||||
workspace_id=workspace_id,
|
||||
base_url=base_url
|
||||
)
|
||||
|
||||
return success(data=share)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/apps/{app_id}/releases/{release_id}/share",
|
||||
summary="删除分享配置"
|
||||
)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_share(
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""删除分享配置
|
||||
|
||||
- 删除后,公开访问链接将失效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
service.delete_share(
|
||||
release_id=release_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="分享配置已删除")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/apps/{app_id}/releases/{release_id}/share/regenerate-token",
|
||||
summary="重新生成分享链接"
|
||||
)
|
||||
@cur_workspace_access_guard()
|
||||
def regenerate_token(
|
||||
app_id: uuid.UUID,
|
||||
release_id: uuid.UUID,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""重新生成分享 token
|
||||
|
||||
- 旧的分享链接将失效
|
||||
- 生成新的唯一 token
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
base_url = get_base_url(request)
|
||||
|
||||
service = ReleaseShareService(db)
|
||||
share = service.regenerate_token(
|
||||
release_id=release_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
share_schema = service._convert_to_schema(share, base_url)
|
||||
return success(data=share_schema, msg="分享链接已重新生成")
|
||||
17
api/app/controllers/service/__init__.py
Normal file
17
api/app/controllers/service/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Service API Controllers - 基于 API Key 认证的服务接口
|
||||
|
||||
路由前缀: /v1
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_controller, memory_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
service_router.include_router(app_api_controller.router)
|
||||
service_router.include_router(rag_api_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
16
api/app/controllers/service/app_api_controller.py
Normal file
16
api/app/controllers/service/app_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""App 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
16
api/app/controllers/service/memory_api_controller.py
Normal file
16
api/app/controllers/service/memory_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
16
api/app/controllers/service/rag_api_controller.py
Normal file
16
api/app/controllers/service/rag_api_controller.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""RAG 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_knowledge():
|
||||
"""列出可访问的知识库(占位)"""
|
||||
return success(data=[], msg="RAG API - Coming Soon")
|
||||
23
api/app/controllers/setup_controller.py
Normal file
23
api/app/controllers/setup_controller.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/setup",
|
||||
tags=["Setup"],
|
||||
)
|
||||
|
||||
@router.post("", summary="Create the first superuser", response_model=ApiResponse)
|
||||
def setup_initial_user(db: Session = Depends(get_db)):
|
||||
"""
|
||||
Create the initial superuser. This can only be run once.
|
||||
Reads credentials from environment variables.
|
||||
"""
|
||||
user = user_service.create_initial_superuser(db)
|
||||
if not user:
|
||||
return success(msg="Superuser already exists.")
|
||||
return success(msg="Superuser created successfully.")
|
||||
25
api/app/controllers/task_controller.py
Normal file
25
api/app/controllers/task_controller.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import APIRouter, status
|
||||
from app.schemas.item_schema import Item
|
||||
from app.services import task_service
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/tasks",
|
||||
tags=["Tasks"],
|
||||
)
|
||||
|
||||
@router.post("/process_item", status_code=status.HTTP_202_ACCEPTED)
|
||||
def process_item_task(item: Item):
|
||||
"""
|
||||
This endpoint receives an item, and instead of processing it directly,
|
||||
it sends a task to the Celery queue via the task service.
|
||||
"""
|
||||
task_id = task_service.create_processing_task(item.dict())
|
||||
return {"message": "Task accepted. The item is being processed in the background.", "task_id": task_id}
|
||||
|
||||
@router.get("/result/{task_id}")
|
||||
def get_task_result_controller(task_id: str):
|
||||
"""
|
||||
This endpoint allows clients to check the status and result of a
|
||||
previously submitted task using its ID, by calling the task service.
|
||||
"""
|
||||
return task_service.get_task_result(task_id)
|
||||
126
api/app/controllers/test_controller.py
Normal file
126
api/app/controllers/test_controller.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/test",
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(f"/llm/{{model_id}}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
api_logger.error(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
try:
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
), type=config.type)
|
||||
print(llm.dict())
|
||||
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | llm
|
||||
answer = chain.invoke({"question": "What is LangChain?"})
|
||||
print("Answer:", answer)
|
||||
return success(msg="测试LLM成功", data={"question": "What is LangChain?", "answer": answer})
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"测试LLM失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse)
|
||||
def test_embedding(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
api_logger.error(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
model = RedBearEmbeddings(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
embeddings = model.embed_documents(data)
|
||||
print(embeddings)
|
||||
query = "我想找一个适合学习的地方。"
|
||||
query_embedding = model.embed_query(query)
|
||||
print(query_embedding)
|
||||
|
||||
return success(msg="测试LLM成功")
|
||||
|
||||
|
||||
@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse)
|
||||
def test_rerank(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
if not config:
|
||||
api_logger.error(f"模型ID {model_id} 不存在")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
model = RedBearRerank(RedBearModelConfig(
|
||||
model_name=apiConfig.model_name,
|
||||
provider=apiConfig.provider,
|
||||
api_key=apiConfig.api_key,
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
query = "最近哪家咖啡店评价最好?"
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
scores = model.rerank(query=query, documents=data, top_n=3)
|
||||
print(scores)
|
||||
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||
376
api/app/controllers/upload_controller.py
Normal file
376
api/app/controllers/upload_controller.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
Upload Controller for Generic File Upload System
|
||||
Handles HTTP requests for file upload, download, deletion, and metadata updates.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Any
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, File, UploadFile, Form
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.generic_file_schema import (
|
||||
GenericFileResponse,
|
||||
FileMetadataUpdate,
|
||||
UploadResultSchema,
|
||||
BatchUploadResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.upload_enums import UploadContext
|
||||
from app.services.upload_service import UploadService
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.exceptions import (
|
||||
ValidationException,
|
||||
ResourceNotFoundException,
|
||||
FileUploadException,
|
||||
BusinessException
|
||||
)
|
||||
|
||||
# Get logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create router
|
||||
router = APIRouter(
|
||||
prefix="/api",
|
||||
tags=["upload"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
# Initialize upload service
|
||||
upload_service = UploadService()
|
||||
|
||||
|
||||
@router.post("/upload", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(..., description="要上传的文件"),
|
||||
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
单文件上传接口
|
||||
|
||||
- **file**: 要上传的文件
|
||||
- **context**: 上传上下文,决定文件存储位置和验证规则
|
||||
- **metadata**: 可选的文件元数据,JSON格式字符串
|
||||
|
||||
返回上传成功的文件信息
|
||||
"""
|
||||
logger.info(f"Upload request: filename={file.filename}, context={context}, user={current_user.id}")
|
||||
|
||||
try:
|
||||
# Validate and parse context
|
||||
try:
|
||||
upload_context = UploadContext(context)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid upload context: {context}")
|
||||
raise ValidationException(
|
||||
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
|
||||
field="context"
|
||||
)
|
||||
|
||||
# Parse metadata if provided
|
||||
file_metadata = {}
|
||||
if metadata:
|
||||
try:
|
||||
file_metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid metadata JSON: {metadata}")
|
||||
raise ValidationException(
|
||||
"元数据必须是有效的JSON格式",
|
||||
field="metadata"
|
||||
)
|
||||
|
||||
# Upload file
|
||||
db_file = upload_service.upload_file(
|
||||
file=file,
|
||||
context=upload_context,
|
||||
metadata=file_metadata,
|
||||
current_user=current_user,
|
||||
db=db
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
file_response = GenericFileResponse.model_validate(db_file)
|
||||
|
||||
logger.info(f"Upload successful: {file.filename} (ID: {db_file.id})")
|
||||
return success(data=file_response.dict(), msg="文件上传成功")
|
||||
|
||||
except BusinessException:
|
||||
# Business exceptions are handled by global exception handlers
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Upload failed: {str(e)}")
|
||||
# Wrap unknown exceptions as FileUploadException
|
||||
raise FileUploadException(
|
||||
f"文件上传失败: {str(e)}",
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload/batch", response_model=ApiResponse)
|
||||
async def upload_files_batch(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件列表"),
|
||||
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
批量文件上传接口
|
||||
|
||||
- **files**: 要上传的文件列表(最多20个)
|
||||
- **context**: 上传上下文,决定文件存储位置和验证规则
|
||||
- **metadata**: 可选的文件元数据,JSON格式字符串,应用于所有文件
|
||||
|
||||
返回每个文件的上传结果
|
||||
"""
|
||||
logger.info(f"Batch upload request: {len(files)} files, context={context}, user={current_user.id}")
|
||||
|
||||
try:
|
||||
# Validate and parse context
|
||||
try:
|
||||
upload_context = UploadContext(context)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid upload context: {context}")
|
||||
raise ValidationException(
|
||||
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
|
||||
field="context"
|
||||
)
|
||||
|
||||
# Parse metadata if provided
|
||||
file_metadata = {}
|
||||
if metadata:
|
||||
try:
|
||||
file_metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Invalid metadata JSON: {metadata}")
|
||||
raise ValidationException(
|
||||
"元数据必须是有效的JSON格式",
|
||||
field="metadata"
|
||||
)
|
||||
|
||||
# Upload files in batch
|
||||
upload_results = upload_service.upload_files_batch(
|
||||
files=files,
|
||||
context=upload_context,
|
||||
metadata=file_metadata,
|
||||
current_user=current_user,
|
||||
db=db
|
||||
)
|
||||
|
||||
# Convert results to response schemas
|
||||
result_schemas = []
|
||||
for result in upload_results:
|
||||
result_schema = UploadResultSchema(
|
||||
success=result.success,
|
||||
file_id=result.file_id,
|
||||
file_name=result.file_name,
|
||||
error=result.error,
|
||||
file_info=None
|
||||
)
|
||||
|
||||
# If upload was successful, get file info
|
||||
if result.success and result.file_id:
|
||||
try:
|
||||
db_file = upload_service.get_file(result.file_id, current_user, db)
|
||||
result_schema.file_info = GenericFileResponse.model_validate(db_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get file info for {result.file_id}: {str(e)}")
|
||||
|
||||
result_schemas.append(result_schema)
|
||||
|
||||
# Create batch response
|
||||
batch_response = BatchUploadResponse(
|
||||
total=len(files),
|
||||
success_count=sum(1 for r in upload_results if r.success),
|
||||
failed_count=sum(1 for r in upload_results if not r.success),
|
||||
results=result_schemas
|
||||
)
|
||||
|
||||
logger.info(f"Batch upload completed: {batch_response.success_count}/{batch_response.total} successful")
|
||||
return success(data=batch_response.dict(), msg="批量上传完成")
|
||||
|
||||
except BusinessException:
|
||||
# Business exceptions are handled by global exception handlers
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Batch upload failed: {str(e)}")
|
||||
# Wrap unknown exceptions as FileUploadException
|
||||
raise FileUploadException(
|
||||
f"批量上传失败: {str(e)}",
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
file_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> Any:
|
||||
"""
|
||||
文件下载接口
|
||||
|
||||
- **file_id**: 文件ID
|
||||
|
||||
返回文件内容供下载
|
||||
"""
|
||||
logger.info(f"Download request: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
try:
|
||||
# Parse file_id
|
||||
import uuid
|
||||
try:
|
||||
file_uuid = uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid file ID format: {file_id}")
|
||||
raise ValidationException(
|
||||
"无效的文件ID格式",
|
||||
field="file_id"
|
||||
)
|
||||
|
||||
# Get file from database
|
||||
db_file = upload_service.get_file(file_uuid, current_user, db)
|
||||
|
||||
# Check if physical file exists
|
||||
storage_path = Path(db_file.storage_path)
|
||||
if not storage_path.exists():
|
||||
logger.error(f"Physical file not found: {storage_path}")
|
||||
raise ResourceNotFoundException(
|
||||
"文件",
|
||||
str(file_uuid),
|
||||
context={"detail": "文件未找到(可能已被删除)"}
|
||||
)
|
||||
|
||||
# Return file response
|
||||
logger.info(f"Download successful: {db_file.file_name} (ID: {file_id})")
|
||||
return FileResponse(
|
||||
path=str(storage_path),
|
||||
filename=db_file.file_name,
|
||||
media_type=db_file.mime_type or "application/octet-stream"
|
||||
)
|
||||
|
||||
except BusinessException:
|
||||
# Business exceptions are handled by global exception handlers
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed: {str(e)}")
|
||||
# Wrap unknown exceptions
|
||||
raise FileUploadException(
|
||||
f"文件下载失败: {str(e)}",
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=ApiResponse)
|
||||
async def delete_file(
|
||||
file_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
文件删除接口
|
||||
|
||||
- **file_id**: 文件ID
|
||||
|
||||
删除文件(包括物理文件和数据库记录)
|
||||
"""
|
||||
logger.info(f"Delete request: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
try:
|
||||
# Parse file_id
|
||||
import uuid
|
||||
try:
|
||||
file_uuid = uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid file ID format: {file_id}")
|
||||
raise ValidationException(
|
||||
"无效的文件ID格式",
|
||||
field="file_id"
|
||||
)
|
||||
|
||||
# Delete file
|
||||
upload_service.delete_file(file_uuid, current_user, db)
|
||||
|
||||
logger.info(f"Delete successful: file_id={file_id}")
|
||||
return success(msg="文件删除成功")
|
||||
|
||||
except BusinessException:
|
||||
# Business exceptions are handled by global exception handlers
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Delete failed: {str(e)}")
|
||||
# Wrap unknown exceptions
|
||||
raise FileUploadException(
|
||||
f"文件删除失败: {str(e)}",
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
@router.put("/files/{file_id}", response_model=ApiResponse)
|
||||
async def update_file_metadata(
|
||||
file_id: str,
|
||||
update_data: FileMetadataUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
文件元数据更新接口
|
||||
|
||||
- **file_id**: 文件ID
|
||||
- **update_data**: 要更新的元数据
|
||||
|
||||
更新文件的元数据(文件名、自定义元数据、公开状态)
|
||||
"""
|
||||
logger.info(f"Update metadata request: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
try:
|
||||
# Parse file_id
|
||||
import uuid
|
||||
try:
|
||||
file_uuid = uuid.UUID(file_id)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid file ID format: {file_id}")
|
||||
raise ValidationException(
|
||||
"无效的文件ID格式",
|
||||
field="file_id"
|
||||
)
|
||||
|
||||
# Convert update data to dict, excluding unset fields
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
|
||||
if not update_dict:
|
||||
logger.warning(f"No fields to update for file: {file_id}")
|
||||
raise ValidationException(
|
||||
"没有提供要更新的字段",
|
||||
field="update_data"
|
||||
)
|
||||
|
||||
# Update file metadata
|
||||
updated_file = upload_service.update_file_metadata(
|
||||
file_uuid, update_dict, current_user, db
|
||||
)
|
||||
|
||||
# Convert to response schema
|
||||
file_response = GenericFileResponse.model_validate(updated_file)
|
||||
|
||||
logger.info(f"Update metadata successful: file_id={file_id}")
|
||||
return success(data=file_response.dict(), msg="文件元数据更新成功")
|
||||
|
||||
except BusinessException:
|
||||
# Business exceptions are handled by global exception handlers
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Update metadata failed: {str(e)}")
|
||||
# Wrap unknown exceptions
|
||||
raise FileUploadException(
|
||||
f"文件元数据更新失败: {str(e)}",
|
||||
cause=e
|
||||
)
|
||||
183
api/app/controllers/user_controller.py
Normal file
183
api/app/controllers/user_controller.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/users",
|
||||
tags=["Users"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/superuser", response_model=ApiResponse)
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
|
||||
result = user_service.create_superuser(db=db, user=user, current_user=current_superuser)
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
result = user_service.deactivate_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
|
||||
result = user_service.activate_user(
|
||||
db=db, user_id_to_activate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
|
||||
result = user_service.get_user(
|
||||
db=db, user_id=current_user.id, current_user=current_user
|
||||
)
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
|
||||
# 设置当前工作空间的角色和名称
|
||||
if current_user.current_workspace_id:
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
current_workspace = workspace_repo.get_workspace_by_id(current_user.current_workspace_id)
|
||||
if current_workspace:
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
|
||||
superusers = user_service.get_tenant_superusers(
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
|
||||
result = user_service.get_user(
|
||||
db=db, user_id=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
|
||||
await user_service.change_password(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
old_password=request.old_password,
|
||||
new_password=request.new_password,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
|
||||
user, generated_password = await user_service.admin_change_password(
|
||||
db=db,
|
||||
target_user_id=request.user_id,
|
||||
new_password=request.new_password,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
342
api/app/controllers/workspace_controller.py
Normal file
342
api/app/controllers/workspace_controller.py
Normal file
@@ -0,0 +1,342 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from app.models.user_model import User
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.workspace_model import Workspace, InviteStatus
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.workspace_schema import (
|
||||
WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse,
|
||||
WorkspaceInviteCreate, WorkspaceInviteResponse,
|
||||
InviteValidateResponse, InviteAcceptRequest,
|
||||
WorkspaceMemberUpdate, WorkspaceMemberItem
|
||||
)
|
||||
from app.schemas import knowledge_schema
|
||||
from app.services import workspace_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.services import knowledge_service, document_service
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
# 需要认证的路由器
|
||||
router = APIRouter(
|
||||
prefix="/workspaces",
|
||||
tags=["Workspaces"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
# 公开路由器(不需要认证)
|
||||
public_router = APIRouter(
|
||||
prefix="/workspaces",
|
||||
tags=["Workspaces"]
|
||||
)
|
||||
|
||||
|
||||
def _convert_members_to_table_items(members):
|
||||
"""将工作空间成员列表转换为表格项"""
|
||||
return [
|
||||
WorkspaceMemberItem(
|
||||
id=m.id,
|
||||
username=m.user.username,
|
||||
account=m.user.email,
|
||||
role=m.role,
|
||||
last_login_at=m.user.last_login_at
|
||||
)
|
||||
for m in members
|
||||
]
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
Args:
|
||||
include_current: 是否包含当前工作空间(默认 True)
|
||||
"""
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在租户 {current_tenant.name} 中请求获取工作空间列表",
|
||||
extra={"include_current": include_current}
|
||||
)
|
||||
|
||||
workspaces = workspace_service.get_user_workspaces(db, current_user)
|
||||
|
||||
# 如果不包含当前工作空间,则过滤掉
|
||||
if not include_current and current_user.current_workspace_id:
|
||||
workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id]
|
||||
api_logger.debug(
|
||||
f"过滤掉当前工作空间",
|
||||
extra={"current_workspace_id": str(current_user.current_workspace_id)}
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||
|
||||
result = workspace_service.create_workspace(
|
||||
db=db, workspace=workspace, user=current_user)
|
||||
|
||||
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 ID: {workspace_id}")
|
||||
|
||||
result = workspace_service.update_workspace(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
workspace_in=workspace,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
|
||||
members = workspace_service.get_workspace_members(
|
||||
db=db,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def update_workspace_members(
|
||||
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
members = workspace_service.update_workspace_member_roles(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
updates=updates,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
|
||||
workspace_service.delete_workspace_member(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
member_id=member_id,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@router.post("/invites", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求为工作空间 {workspace_id} 创建邀请: {invite_data.email}")
|
||||
|
||||
result = workspace_service.create_workspace_invite(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
invite_data=invite_data,
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_invites(
|
||||
|
||||
status_filter: Optional[InviteStatus] = Query(None, alias="status"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的邀请列表")
|
||||
|
||||
invites = workspace_service.get_workspace_invites(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user,
|
||||
status=status_filter,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def revoke_workspace_invite(
|
||||
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求撤销工作空间 {workspace_id} 的邀请 {invite_id}")
|
||||
|
||||
result = workspace_service.revoke_workspace_invite(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
invite_id=invite_id,
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
# # 创建一个新的路由器用于公开接口
|
||||
# public_router = APIRouter(
|
||||
# prefix="/invites",
|
||||
# tags=["Public Invites"]
|
||||
# )
|
||||
|
||||
# @public_router.get("/validate", response_model=ApiResponse)
|
||||
# def validate_invite_token(
|
||||
# token: str = Query(..., description="邀请令牌"),
|
||||
# db: Session = Depends(get_db),
|
||||
# ):
|
||||
# """验证邀请令牌(公开接口)"""
|
||||
# api_logger.info(f"验证邀请令牌请求")
|
||||
@router.put("/{workspace_id}/switch", response_model=ApiResponse)
|
||||
@workspace_access_guard()
|
||||
def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
|
||||
workspace_service.switch_workspace(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的存储类型")
|
||||
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的模型配置")
|
||||
|
||||
configs = workspace_service.get_workspace_models_configs(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
|
||||
if configs is None:
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=configs, msg="模型配置获取成功")
|
||||
|
||||
Reference in New Issue
Block a user