Initial commit

This commit is contained in:
Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
"""管理端接口 - 基于 JWT 认证
路由前缀: /
认证方式: JWT Token
"""
from fastapi import APIRouter
from . import (
model_controller,
task_controller,
test_controller,
user_controller,
auth_controller,
workspace_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
app_controller,
upload_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
api_key_controller,
release_share_controller,
public_share_controller,
multi_agent_controller,
)
# 创建管理端 API 路由器
manager_router = APIRouter()
# 注册所有管理端路由
manager_router.include_router(task_controller.router)
manager_router.include_router(user_controller.router)
manager_router.include_router(auth_controller.router)
manager_router.include_router(workspace_controller.router)
manager_router.include_router(workspace_controller.public_router) # 公开路由(无需认证)
manager_router.include_router(setup_controller.router)
manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router)
manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(memory_storage_controller.router)
manager_router.include_router(api_key_controller.router)
manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,151 @@
"""API Key 管理接口 - 基于 JWT 认证"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.core.response_utils import success
from app.schemas import api_key_schema
from app.schemas.response_schema import ApiResponse
from app.services.api_key_service import ApiKeyService
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
logger = get_business_logger()
@router.post("", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_api_key(
data: api_key_schema.ApiKeyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建 API Key
- 创建后返回明文 API Key仅此一次
- 支持设置权限范围、速率限制、配额等
"""
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.create_api_key(
db,
workspace_id=workspace_id,
user_id=current_user.id,
data=data
)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
return success(data=response_data, msg="API Key 创建成功")
@router.get("", response_model=ApiResponse)
@cur_workspace_access_guard()
def list_api_keys(
type: api_key_schema.ApiKeyType = Query(None),
is_active: bool = Query(None),
resource_id: uuid.UUID = Query(None),
page: int = Query(1, ge=1),
pagesize: int = Query(10, ge=1, le=100),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""列出 API Keys"""
workspace_id = current_user.current_workspace_id
query = api_key_schema.ApiKeyQuery(
type=type,
is_active=is_active,
resource_id=resource_id,
page=page,
pagesize=pagesize
)
result = ApiKeyService.list_api_keys(db, workspace_id, query)
return success(data=result)
@router.get("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取 API Key 详情"""
workspace_id = current_user.current_workspace_id
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
return success(data=api_key_schema.ApiKey.model_validate(api_key))
@router.put("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_api_key(
api_key_id: uuid.UUID,
data: api_key_schema.ApiKeyUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新 API Key"""
workspace_id = current_user.current_workspace_id
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
@router.delete("/{api_key_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除 API Key"""
workspace_id = current_user.current_workspace_id
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
return success(msg="API Key 删除成功")
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
@cur_workspace_access_guard()
def regenerate_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""重新生成 API Key
- 生成新的 API Key 并返回明文(仅此一次)
- 旧的 API Key 立即失效
"""
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
return success(data=response_data, msg="API Key 重新生成成功")
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_api_key_stats(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取 API Key 使用统计"""
workspace_id = current_user.current_workspace_id
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
return success(data=stats)

View File

@@ -0,0 +1,716 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.models import User
from app.repositories import knowledge_repository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import app_service, workspace_service
from app.services.app_service import AppService
from app.services.agent_config_helper import enrich_agent_config
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
from fastapi.responses import StreamingResponse
from app.models.app_model import AppType
from app.core.error_codes import BizCode
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@router.post("", summary="创建应用(可选创建 Agent 配置)")
@cur_workspace_access_guard()
def create_app(
payload: app_schema.AppCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.create_app(db, user_id=current_user.id, workspace_id=workspace_id, data=payload)
return success(data=app_schema.App.model_validate(app))
@router.get("", summary="应用列表(分页)")
@cur_workspace_access_guard()
def list_apps(
type: str | None = None,
visibility: str | None = None,
status: str | None = None,
search: str | None = None,
include_shared: bool = True,
page: int = 1,
pagesize: int = 10,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出应用
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
items_orm, total = app_service.list_apps(
db,
workspace_id=workspace_id,
type=type,
visibility=visibility,
status=status,
search=search,
include_shared=include_shared,
page=page,
pagesize=pagesize,
)
# 使用 AppService 的转换方法来设置 is_shared 字段
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard()
def get_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取应用详细信息
- 支持获取本工作空间的应用
- 支持获取分享给本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
app = service.get_app(app_id, workspace_id)
# 转换为 Schema 并设置 is_shared 字段
app_schema_obj = service._convert_to_schema(app, workspace_id)
return success(data=app_schema_obj)
@router.put("/{app_id}", summary="更新应用基本信息")
@cur_workspace_access_guard()
def update_app(
app_id: uuid.UUID,
payload: app_schema.AppUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
app = app_service.update_app(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=app_schema.App.model_validate(app))
@router.delete("/{app_id}", summary="删除应用")
@cur_workspace_access_guard()
def delete_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""删除应用
会级联删除:
- Agent 配置
- 发布版本
- 会话和消息
"""
workspace_id = current_user.current_workspace_id
logger.info(
f"用户请求删除应用",
extra={
"app_id": str(app_id),
"user_id": str(current_user.id),
"workspace_id": str(workspace_id)
}
)
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
return success(msg="应用删除成功")
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""复制应用(包括基础信息和配置)
- 复制应用的基础信息(名称、描述、图标等)
- 复制 Agent 配置(如果是 agent 类型)
- 新应用默认为草稿状态
- 不影响原应用
"""
workspace_id = current_user.current_workspace_id
logger.info(
f"用户请求复制应用",
extra={
"source_app_id": str(app_id),
"user_id": str(current_user.id),
"workspace_id": str(workspace_id),
"new_name": new_name
}
)
service = AppService(db)
new_app = service.copy_app(
app_id=app_id,
user_id=current_user.id,
workspace_id=workspace_id,
new_name=new_name
)
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
@router.put("/{app_id}/config", summary="更新 Agent 配置")
@cur_workspace_access_guard()
def update_agent_config(
app_id: uuid.UUID,
payload: app_schema.AgentConfigUpdate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_agent_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
cfg = enrich_agent_config(cfg)
return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.get("/{app_id}/config", summary="获取 Agent 配置")
@cur_workspace_access_guard()
def get_agent_config(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
# 配置总是存在(不存在时返回默认模板)
cfg = enrich_agent_config(cfg)
return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard()
def publish_app(
app_id: uuid.UUID,
payload: app_schema.PublishRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.publish(
db,
app_id=app_id,
publisher_id=current_user.id,
workspace_id=workspace_id,
version_name = payload.version_name,
release_notes=payload.release_notes
)
return success(data=app_schema.AppRelease.model_validate(release))
@router.get("/{app_id}/release", summary="获取当前发布版本")
@cur_workspace_access_guard()
def get_current_release(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.get_current_release(db, app_id=app_id, workspace_id=workspace_id)
if not release:
return success(data=None)
return success(data=app_schema.AppRelease.model_validate(release))
@router.get("/{app_id}/releases", summary="列出历史发布版本(倒序)")
@cur_workspace_access_guard()
def list_releases(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
releases = app_service.list_releases(db, app_id=app_id, workspace_id=workspace_id)
data = [app_schema.AppRelease.model_validate(r) for r in releases]
return success(data=data)
@router.post("/{app_id}/rollback/{version}", summary="回滚到指定版本")
@cur_workspace_access_guard()
def rollback(
app_id: uuid.UUID,
version: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
release = app_service.rollback(db, app_id=app_id, version=version, workspace_id=workspace_id)
return success(data=app_schema.AppRelease.model_validate(release))
@router.post("/{app_id}/share", summary="分享应用到其他工作空间")
@cur_workspace_access_guard()
def share_app(
app_id: uuid.UUID,
payload: app_schema.AppShareCreate,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""分享应用到其他工作空间
- 只能分享自己工作空间的应用
- 不能分享到自己的工作空间
- 同一个应用不能重复分享到同一个工作空间
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.share_app(
app_id=app_id,
target_workspace_ids=payload.target_workspace_ids,
user_id=current_user.id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
@router.delete("/{app_id}/share/{target_workspace_id}", summary="取消应用分享")
@cur_workspace_access_guard()
def unshare_app(
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""取消应用分享
- 只能取消自己工作空间应用的分享
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
service.unshare_app(
app_id=app_id,
target_workspace_id=target_workspace_id,
workspace_id=workspace_id
)
return success(msg="应用分享已取消")
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard()
def list_app_shares(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出应用的所有分享记录
- 只能查看自己工作空间应用的分享记录
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.list_app_shares(
app_id=app_id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard()
async def draft_run(
app_id: uuid.UUID,
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
试运行 Agent使用当前的草稿配置未发布的配置
- 不需要发布应用即可测试
- 使用当前的 AgentConfig 配置
- 支持流式和非流式返回
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
service = AppService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
service._validate_app_accessible(app, workspace_id)
if app.type == AppType.AGENT:
service._check_agent_config(app_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
logger.debug(
f"开始非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id),
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
f"试运行返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
}
)
# 验证结果
try:
validated_result = app_schema.DraftRunResponse.model_validate(result)
logger.debug(f"结果验证成功")
return success(data=validated_result)
except Exception as e:
logger.error(
f"结果验证失败",
extra={
"error": str(e),
"error_type": str(type(e)),
"result": str(result)[:200]
}
)
raise
elif app.type == AppType.MULTI_AGENT:
# 1. 检查多智能体配置完整性
service._check_multi_agent_config(app_id)
# 2. 构建多智能体运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
user_id=payload.user_id,
variables=payload.variables or {},
use_llm_routing=True # 默认启用 LLM 路由
)
# 3. 流式返回
if payload.stream:
logger.debug(
f"开始多智能体流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
f"开始多智能体非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
multiservice = MultiAgentService(db)
result = await multiservice.run(app_id, multi_agent_request)
logger.debug(
f"多智能体试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="多 Agent 任务执行成功"
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@cur_workspace_access_guard()
async def draft_run_compare(
app_id: uuid.UUID,
payload: app_schema.DraftRunCompareRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
多模型对比试运行
- 支持对比 1-5 个模型
- 可以是不同的模型,也可以是同一模型的不同参数配置
- 通过 model_parameters 覆盖默认参数
- 支持并行或串行执行(非流式)
- 支持流式返回(串行执行)
- 返回每个模型的运行结果和性能对比
使用场景:
1. 对比不同模型的效果GPT-4 vs Claude vs Gemini
2. 调优模型参数(不同 temperature 的效果对比)
3. 性能和成本分析
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
logger.info(
f"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(payload.models),
"parallel": payload.parallel,
"stream": payload.stream
}
)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.models import ModelConfig
service = AppService(db)
# 1. 验证应用和权限
app = service._get_app_or_404(app_id)
if app.type != "agent":
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 验证所有模型配置
model_configs = []
for model_item in payload.models:
model_config = db.get(ModelConfig, model_item.model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id,
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
)
logger.info(
f"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
"failed": result["failed_count"]
}
)
return success(data=app_schema.DraftRunCompareResponse(**result))

View File

@@ -0,0 +1,195 @@
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.schemas.response_schema import ApiResponse
from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest
from app.schemas.workspace_schema import InviteAcceptRequest
from app.services import auth_service, user_service, workspace_service
from app.core import security
from app.core.config import settings
from app.services.session_service import SessionService
from app.core.logging_config import get_auth_logger, get_security_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.dependencies import get_current_user, oauth2_scheme
from app.models.user_model import User
# 获取专用日志器
auth_logger = get_auth_logger()
security_logger = get_security_logger()
router = APIRouter(tags=["Authentication"])
@router.post("/token", response_model=ApiResponse)
async def login_for_access_token(
form_data: TokenRequest,
db: Session = Depends(get_db)
):
"""用户登录获取token"""
auth_logger.info(f"用户登录请求: {form_data.email}")
# 验证邀请码(如果提供)
invite_info = None
# 验证用户凭据或注册新用户
user = None
if form_data.invite:
auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...")
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
if not invite_info.is_valid:
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
if invite_info.email != form_data.email:
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
try:
# 尝试认证用户
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
if form_data.invite:
auth_service.bind_workspace_with_invite(db=db,
user=user,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id)
except BusinessException as e:
# 用户不存在且有邀请码,尝试注册
if e.code == BizCode.USER_NOT_FOUND:
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
)
elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
else:
# 其他认证失败情况,直接抛出
raise
else:
try:
# 尝试认证用户
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
except BusinessException as e:
# 其他认证失败情况,直接抛出
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
# 创建 tokens
access_token, access_token_id = security.create_access_token(subject=user.id)
refresh_token, refresh_token_id = security.create_refresh_token(subject=user.id)
# 计算过期时间
access_expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
# 单点登录会话管理
if settings.ENABLE_SINGLE_SESSION:
await SessionService.invalidate_old_session(user.id, access_token_id)
await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at)
# 更新最后登录时间
user_service.update_last_login_time(db, user.id)
auth_logger.info(f"用户 {user.username} 登录成功")
return success(
data=Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="登录成功"
)
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db)
):
"""刷新token"""
auth_logger.info("收到token刷新请求")
# 验证 refresh token
userId = security.verify_token(refresh_request.refresh_token, "refresh")
if not userId:
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
# 检查用户是否存在
user = auth_service.get_user_by_id(db, userId)
if not user:
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 检查 refresh token 黑名单
if settings.ENABLE_SINGLE_SESSION:
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
# 生成新 tokens
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
new_refresh_token, new_refresh_token_id = security.create_refresh_token(subject=user.id)
# 计算过期时间
access_expires_at = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_expires_at = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
# 单点登录会话管理
if settings.ENABLE_SINGLE_SESSION:
# 将旧 refresh token 加入黑名单
old_refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if old_refresh_token_id:
await SessionService.blacklist_token(old_refresh_token_id)
# 更新会话
await SessionService.invalidate_old_session(user.id, new_access_token_id)
await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at)
auth_logger.info(f"用户 {user.id} token刷新成功")
return success(
data=Token(
access_token=new_access_token,
refresh_token=new_refresh_token,
token_type="bearer",
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg="token刷新成功"
)
@router.post("/logout", response_model=ApiResponse)
async def logout(
token: str = Depends(oauth2_scheme),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""登出当前用户加入token黑名单并清理会话"""
auth_logger.info(f"用户 {current_user.username} 请求登出")
token_id = security.get_token_id(token)
if not token_id:
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
# 加入黑名单
await SessionService.blacklist_token(token_id)
# 清理会话
if settings.ENABLE_SINGLE_SESSION:
await SessionService.clear_user_session(current_user.username)
auth_logger.info(f"用户 {current_user.username} 登出成功")
return success(msg="登出成功")

View File

@@ -0,0 +1,447 @@
import os
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.core.config import settings
from app.db import get_db
from app.core.rag.llm.cv_model import QWenCV
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models.document_model import Document
from app.models import knowledge_model, knowledgeshare_model
from app.core.rag.models.chunk import DocumentChunk
from app.schemas import chunk_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/chunks",
tags=["chunks"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/{document_id}/previewchunks", response_model=ApiResponse)
async def get_preview_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document block preview list
- Support filtering by document_id
- Support keyword search for segmented content
- Return paging metadata + file list
"""
api_logger.info(f"Paged query document block preview list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 3. Check if the document exists
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 4. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
if not db_file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 5. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 6. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 7. Document parsing & segmentation
def progress_callback(prog=None, msg=None):
print(f"prog: {prog} msg: {msg}\n")
# Prepare to configure vision_model information
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese", # Default to Chinese
base_url=db_knowledge.image2text.api_keys[0].api_base
)
from app.core.rag.app.naive import chunk
res = chunk(filename=file_path,
from_page=0,
to_page=5,
callback=progress_callback,
vision_model=vision_model,
parser_config=db_document.parser_config,
is_root=False)
start_index = (page - 1) * pagesize
end_index = start_index + pagesize
# Use slicing to obtain the data of the current page
paginated_chunk_str_list = res[start_index:end_index]
chunks = []
for idx, item in enumerate(paginated_chunk_str_list):
metadata = {
"doc_id": uuid.uuid4().hex,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(db_document.id),
"knowledge_id": str(db_document.kb_id),
"sort_id": idx,
"status": 1,
}
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
# 8. Return structured response
total = len(res)
result = {
"items": chunks,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
api_logger.info(f"Querying the document block preview list successful: total={total}, returned={len(chunks)} records")
return success(data=result, msg="Querying the document block preview list succeeded")
@router.get("/{kb_id}/{document_id}/chunks", response_model=ApiResponse)
async def get_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="The keywords used to match chunk content"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document chunk list
- Support filtering by document_id
- Support keyword search for segmented content
- Return paging metadata + file list
"""
api_logger.info(f"Paged query document chunk list: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 3. Execute paged query
try:
api_logger.debug(f"Start executing document chunk query")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Document chunk query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of document chunk list succeeded")
@router.post("/{kb_id}/{document_id}/chunk", response_model=ApiResponse)
async def create_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
create_data: chunk_schema.ChunkCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create chunk
"""
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
# 1. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 1. Obtain document information
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 2. Get the sort ID
sort_id = 0
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
sort_id = sort_id + 1
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
# 3. Segmented vector storage
vector_service.add_chunks([chunk])
# 4.update chunk_num
db_document.chunk_num += 1
db.commit()
return success(data=chunk, msg="Document chunk creation successful")
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def get_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve document chunk information based on doc_id
"""
api_logger.info(f"Obtain document chunk information: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
# 1. Obtain knowledge base information
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.get_by_segment(doc_id=doc_id)
if total:
return success(data=items[0], msg="Document chunk query successful")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access"
)
@router.put("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def update_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
update_data: chunk_schema.ChunkUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update document chunk content
"""
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.get_by_segment(doc_id=doc_id)
if total:
chunk = items[0]
chunk.page_content = update_data.content
vector_service.update_by_segment(chunk)
return success(data=chunk, msg="The document chunk has been successfully updated")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access to it"
)
@router.delete("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
async def delete_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
doc_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete document chunk
"""
api_logger.info(f"Request to delete document chunk: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, username: {current_user.username}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if vector_service.text_exists(doc_id):
vector_service.delete_by_ids([doc_id])
# 更新 chunk_num
db_document = db.query(Document).filter(Document.id == document_id).first()
db_document.chunk_num -= 1
db.commit()
return success(msg="The document chunk has been successfully deleted")
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document chunk does not exist or you do not have access to it"
)
@router.get("/retrieve_type", response_model=ApiResponse)
def get_retrieve_types():
return success(msg="Successfully obtained the retrieval type", data=list(chunk_schema.RetrieveType))
@router.post("/retrieval", response_model=Any, status_code=status.HTTP_200_OK)
async def retrieve_chunks(
retrieve_data: chunk_schema.ChunkRetrieve,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
retrieve chunk
"""
api_logger.info(f"retrieve chunk: query={retrieve_data.query}, username: {current_user.username}")
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
]
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters,
current_user=current_user
)
existing_ids.extend(items)
if not existing_ids:
return success(data=[], msg="retrieval successful")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
return success(data=rs, msg="retrieval successful")

View File

@@ -0,0 +1,341 @@
import os
from typing import Optional
import datetime
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import document_model
from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import document_service, file_service, knowledge_service
from app.controllers import file_controller
from app.celery_app import celery_app
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
async def get_documents(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
document_ids: Optional[str] = Query(None, description="document ids, separated by commas"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query document list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query document list: kb_id={kb_id}, page={page}, pagesize={pagesize}, keywords={keywords}, document_ids={document_ids}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
document_model.Document.kb_id == kb_id,
document_model.Document.status == 1
]
if parent_id:
files = file_service.get_files_by_parent_id(db=db, parent_id=parent_id, current_user=current_user)
files_ids = [item.id for item in files]
filters.append(document_model.Document.file_id.in_(files_ids))
# Keyword search (fuzzy matching of file name)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(document_model.Document.file_name.ilike(f"%{keywords}%"))
# document ids
if document_ids:
filters.append(document_model.Document.id.in_(document_ids.split(',')))
# 3. Execute paged query
try:
api_logger.debug(f"Start executing document paging query")
total, items = document_service.get_documents_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Document query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Document query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of document list succeeded")
@router.post("/document", response_model=ApiResponse)
async def create_document(
create_data: document_schema.DocumentCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create document
"""
api_logger.info(f"Create document request: file_name={create_data.file_name}, kb_id={create_data.kb_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating a document: {create_data.file_name}")
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"Document created successfully: {db_document.file_name} (ID: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="Document creation successful")
except Exception as e:
api_logger.error(f"Document creation failed: {create_data.file_name} - {str(e)}")
raise
@router.get("/{document_id}", response_model=ApiResponse)
async def get_document(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve document information based on document_id
"""
api_logger.info(f"Obtain document information: document_id={document_id}, username: {current_user.username}")
try:
# 1. Query document information from the database
api_logger.debug(f"query documentation: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have access: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have access"
)
api_logger.info(f"Document query successful: {db_document.file_name} (ID: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="Successfully obtained document information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Document query failed: document_id={document_id} - {str(e)}")
raise
@router.put("/{document_id}", response_model=ApiResponse)
async def update_document(
document_id: uuid.UUID,
update_data: document_schema.DocumentUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update document information
"""
# 1. Check if the document exists
api_logger.debug(f"Query the document to be updated: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 2. If updating the status, synchronize the document status switch to whether it can be retrieved from the vector database
update_dict = update_data.dict(exclude_unset=True)
if "status" in update_dict:
new_status = update_dict["status"]
if new_status != db_document.status:
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.change_status_by_document_id(document_id=str(document_id), status=new_status)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the document fields: {document_id}")
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_document, field):
old_value = getattr(db_document, field)
if old_value != value:
# update value
setattr(db_document, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
db_document.updated_at = datetime.datetime.now()
# 4. Save to database
try:
db.commit()
db.refresh(db_document)
api_logger.info(f"The document has been successfully updated: {db_document.file_name} (ID: {db_document.id})")
except Exception as e:
db.rollback()
api_logger.error(f"Document update failed: document_id={document_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Document update failed: {str(e)}"
)
# 5. Return the updated document
return success(data=document_schema.Document.model_validate(db_document), msg="Document information updated successfully")
@router.delete("/{document_id}", response_model=ApiResponse)
async def delete_document(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete document
"""
api_logger.info(f"Request to delete document: document_id={document_id}, username: {current_user.username}")
try:
# 1. Check if the document exists
api_logger.debug(f"Check whether the document exists: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
file_id = db_document.file_id
# 2. Delete document
api_logger.debug(f"Perform document delete: {db_document.file_name} (ID: {document_id})")
db.delete(db_document)
db.commit()
# 3. Delete file
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
# 4. Delete vector index
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
api_logger.info(f"The document has been successfully deleted: {db_document.file_name} (ID: {document_id})")
return success(msg="The document has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the document: document_id={document_id} - {str(e)}")
raise
@router.post("/{document_id}/chunks", response_model=ApiResponse)
async def parse_documents(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
parse document
"""
api_logger.info(f"Request to parse document: document_id={document_id}, username: {current_user.username}")
try:
# 1. Check if the document exists
api_logger.debug(f"Check whether the document exists: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"The document does not exist or you do not have permission to access it: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The document does not exist or you do not have permission to access it"
)
# 2. Check if the file exists
api_logger.debug(f"Check whether the file exists: {db_document.file_id}")
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={db_document.file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 3. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. Check if the file exists
if not os.path.exists(file_path):
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 5. Obtain knowledge base information
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 6. Task: Document parsing, vectorization, and storage
# from app.tasks import parse_document
# parse_document(file_path, document_id)
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. The document is being processed in the background.")
except Exception as e:
api_logger.error(f"Failed to parse document: document_id={document_id} - {str(e)}")
raise

View File

@@ -0,0 +1,453 @@
import os
from typing import Any, Optional
from pathlib import Path
import shutil
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import file_model
from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import file_service, document_service
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/files",
tags=["files"]
)
@router.get("/{kb_id}/{parent_id}/files", response_model=ApiResponse)
async def get_files(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query file list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
file_model.File.kb_id == kb_id
]
if parent_id:
filters.append(file_model.File.parent_id == parent_id)
# Keyword search (fuzzy matching of file name)
if keywords:
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
# 3. Execute paged query
try:
api_logger.debug(f"Start executing file paging query")
total, items = file_service.get_files_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of file list succeeded")
@router.post("/folder", response_model=ApiResponse)
def create_folder(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
folder_name: str = '/',
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new folder
"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating a folder: {folder_name}")
create_folder = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=folder_name,
file_ext='folder',
file_size=0,
)
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
return success(data=file_schema.File.model_validate(db_file), msg="Folder creation successful")
except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
raise
@router.post("/file", response_model=ApiResponse)
async def upload_file(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
upload file
"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
# Read the contents of the file
contents = await file.read()
# Check file size
file_size = len(contents)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
# Extract the extension using `os.path.splitext`
_, file_extension = os.path.splitext(file.filename)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=file.filename,
file_ext=file_extension.lower(),
file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}{file_extension}")
# Save file
with open(save_path, "wb") as f:
f.write(contents)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
api_logger.info(f"File upload successfully: {file.filename} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="File upload successful")
@router.post("/customtext", response_model=ApiResponse)
async def custom_text(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
custom text
"""
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
# Check file content size
# 将内容编码为字节UTF-8
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
# Save file
with open(save_path, "wb") as f:
f.write(content_bytes)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any)
async def get_file(
file_id: uuid.UUID,
db: Session = Depends(get_db)
) -> Any:
"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
# 1. Query file information from the database
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
# 4.Return FileResponse (automatically handle download)
return FileResponse(
path=file_path,
filename=db_file.file_name, # Use original file name
media_type="application/octet-stream" # Universal binary stream type
)
@router.put("/{file_id}", response_model=ApiResponse)
async def update_file(
file_id: uuid.UUID,
update_data: file_schema.FileUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
api_logger.debug(f"Query the file to be updated: {file_id}")
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.items():
if hasattr(db_file, field):
old_value = getattr(db_file, field)
if old_value != value:
# update value
setattr(db_file, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try:
db.commit()
db.refresh(db_file)
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
except Exception as e:
db.rollback()
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File update failed: {str(e)}"
)
# 4. Return the updated file
return success(data=file_schema.File.model_validate(db_file), msg="File information updated successfully")
@router.delete("/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete a file or folder
"""
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
await _delete_file(db=db, file_id=file_id, current_user=current_user)
return success(msg="File deleted successfully")
async def _delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> None:
"""
Delete a file or folder
"""
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct physical path
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.id)
) if db_file.file_ext == 'folder' else Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Delete physical files/folders
try:
if file_path.exists():
if db_file.file_ext == 'folder':
shutil.rmtree(file_path) # Recursively delete folders
else:
file_path.unlink() # Delete a single file
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete physical file/folder: {str(e)}"
)
# 4.Delete db_file
if db_file.file_ext == 'folder':
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
db.delete(db_file)
db.commit()

View File

@@ -0,0 +1,305 @@
from typing import Optional
import datetime
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import knowledge_model, document_model, file_model
from app.schemas import knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/knowledges",
tags=["knowledges"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/knowledgetype", response_model=ApiResponse)
def get_knowledge_types():
return success(msg="Successfully obtained the knowledge type", data=list(knowledge_model.KnowledgeType))
@router.get("/permissiontype", response_model=ApiResponse)
def get_permission_types():
return success(msg="Successfully obtained the knowledge permission type", data=list(knowledge_model.PermissionType))
@router.get("/parsertype", response_model=ApiResponse)
def get_parser_types():
return success(msg="Successfully obtained the knowledge parser type", data=list(knowledge_model.ParserType))
@router.get("/knowledges", response_model=ApiResponse)
async def get_knowledges(
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (knowledge base name)"),
kb_ids: Optional[str] = Query(None, description="Knowledge base ids, separated by commas"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the knowledge base list in pages
- Support filtering by parent_id
- Support keyword search for knowledge base names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query knowledge base list: workspace_id={current_user.current_workspace_id}, page={page}, pagesize={pagesize}, keywords={keywords}, kb_ids={kb_ids}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
]
if parent_id:
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
# Keyword search (fuzzy matching of knowledge base name)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(
or_(
knowledge_model.Knowledge.name.ilike(f"%{keywords}%"),
knowledge_model.Knowledge.description.ilike(f"%{keywords}%")
)
)
# Knowledge base ids
if kb_ids:
filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(',')))
else:
filters.append(knowledge_model.Knowledge.status != 2)
# 3. Execute paged query
try:
api_logger.debug(f"Start executing knowledge base paging query")
total, items = knowledge_service.get_knowledges_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Knowledge base query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Knowledge base query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page*pagesize < total else False
}
}
return success(data=result, msg="Query of knowledge base list successful")
@router.post("/knowledge", response_model=ApiResponse)
async def create_knowledge(
create_data: knowledge_schema.KnowledgeCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create knowledge
"""
api_logger.info(f"Request to create a knowledge base: name={create_data.name}, workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the knowledge base: {create_data.name}")
# 1. Check if the knowledge base name already exists
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=create_data.name, current_user=current_user)
if db_knowledge_exist:
api_logger.warning(f"The knowledge base name already exists: {create_data.name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The knowledge base name already exists: {create_data.name}"
)
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=create_data, current_user=current_user)
api_logger.info(f"The knowledge base has been successfully created: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the knowledge base failed: {create_data.name} - {str(e)}")
raise
@router.get("/{knowledge_id}", response_model=ApiResponse)
async def get_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve knowledge base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
api_logger.info(f"Knowledge base query successful: {db_knowledge.name} (ID: {db_knowledge.id})")
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="Successfully obtained knowledge base information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Knowledge base query failed: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.put("/{knowledge_id}", response_model=ApiResponse)
async def update_knowledge(
knowledge_id: uuid.UUID,
update_data: knowledge_schema.KnowledgeUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
api_logger.info(f"Update knowledge base request: knowledge_id={knowledge_id}, username: {current_user.username}")
db_knowledge = await _update_knowledge(knowledge_id=knowledge_id, update_data=update_data, db=db, current_user=current_user)
return success(data=knowledge_schema.Knowledge.model_validate(db_knowledge), msg="The knowledge base information has been successfully updated")
async def _update_knowledge(
knowledge_id: uuid.UUID,
update_data: knowledge_schema.KnowledgeUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> knowledge_schema.Knowledge:
"""
Update knowledge base information
"""
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Query the knowledge base to be updated: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. If updating the embedding_id, delete the knowledge base vector index, reset all document parsing progress to 0, and set chunk_num to 0
update_dict = update_data.dict(exclude_unset=True)
if "name" in update_dict:
name = update_dict["name"]
if name != db_knowledge.name:
# Check if the knowledge base name already exists
db_knowledge_exist = knowledge_service.get_knowledge_by_name(db, name=name, current_user=current_user)
if db_knowledge_exist:
api_logger.warning(f"The knowledge base name already exists: {name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The knowledge base name already exists: {name}"
)
if "embedding_id" in update_dict:
embedding_id = update_dict["embedding_id"]
if embedding_id != db_knowledge.embedding_id:
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
vector_service.delete()
document_service.reset_documents_progress_by_kb_id(db, kb_id=db_knowledge.id, current_user=current_user)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the knowledge base fields: {knowledge_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items():
if hasattr(db_knowledge, field):
old_value = getattr(db_knowledge, field)
if old_value != value:
# update value
setattr(db_knowledge, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
db_knowledge.updated_at = datetime.datetime.now()
# 3. Save to database
db.commit()
db.refresh(db_knowledge)
api_logger.info(f"The knowledge base has been successfully updated: {db_knowledge.name} (ID: {db_knowledge.id})")
# 4. Return the updated knowledge base
return db_knowledge
except HTTPException:
raise
except Exception as e:
db.rollback()
api_logger.error(f"Knowledge base update failed: knowledge_id={knowledge_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Knowledge base update failed: {str(e)}"
)
@router.delete("/{knowledge_id}", response_model=ApiResponse)
async def delete_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Soft-delete knowledge base
"""
api_logger.info(f"Request to delete knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Check whether the knowledge base exists
api_logger.debug(f"Check whether the knowledge base exists: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or you do not have permission to access it: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or you do not have permission to access it"
)
# 2. Soft-delete knowledge base
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
db_knowledge.status = 2
db.commit()
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
return success(msg="The knowledge base has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -0,0 +1,199 @@
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.models import knowledgeshare_model, knowledge_model
from app.schemas import knowledgeshare_schema, knowledge_schema
from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledgeshare_service, knowledge_service
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/knowledgeshares",
tags=["knowledgeshares"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/{kb_id}/knowledgeshares", response_model=ApiResponse)
async def get_knowledgeshares(
kb_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query knowledge base sharing list
- Support filtering by kb_id
- Support dynamic sorting
- Return paging metadata + share list
"""
api_logger.info(
f"Query knowledge base sharing list: workspace_id={current_user.current_workspace_id}, kb_id={kb_id}, page={page}, pagesize={pagesize}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = [
knowledgeshare_model.KnowledgeShare.source_workspace_id == current_user.current_workspace_id,
knowledgeshare_model.KnowledgeShare.source_kb_id == kb_id
]
# 3. Execute paged query
try:
api_logger.debug(f"Start executing knowledge base sharing and paging query")
total, items = knowledgeshare_service.get_knowledgeshares_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"Knowledge base sharing query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"Knowledge base sharing query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of knowledge base sharing list successful")
@router.post("/knowledgeshare", response_model=ApiResponse)
async def create_knowledgeshare(
create_data: knowledgeshare_schema.KnowledgeShareCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create knowledgeshare
"""
api_logger.info(
f"Create a knowledge base sharing request: source_kb_id={create_data.source_kb_id}, source_workspace_id={current_user.current_workspace_id}, username: {current_user.username}")
try:
# 1.Create a knowledge base with permission_id=knowledge_model.PermissionType.Share
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=create_data.source_kb_id, current_user=current_user)
knowledge = knowledge_schema.KnowledgeCreate(
workspace_id=create_data.target_workspace_id,
created_by=current_user.id,
parent_id=create_data.target_workspace_id,
name=db_knowledge.name,
description=db_knowledge.description,
avatar=db_knowledge.avatar,
type=db_knowledge.type,
permission_id=knowledge_model.PermissionType.Share,
embedding_id=db_knowledge.embedding_id,
reranker_id=db_knowledge.reranker_id,
llm_id=db_knowledge.llm_id,
image2text_id=db_knowledge.image2text_id,
doc_num=db_knowledge.doc_num,
chunk_num=db_knowledge.chunk_num,
parser_id=db_knowledge.parser_id,
parser_config=db_knowledge.parser_config
)
db_knowledge = knowledge_service.create_knowledge(db=db, knowledge=knowledge, current_user=current_user)
# 2. Create a knowledge base for sharing
api_logger.debug(f"Start creating the knowledge base sharing: {db_knowledge.name}")
create_data.target_kb_id = db_knowledge.id
db_knowledgeshare = knowledgeshare_service.create_knowledgeshare(db=db, knowledgeshare=create_data, current_user=current_user)
api_logger.info(f"The knowledge base sharing has been successfully created: (ID: {db_knowledgeshare.id})")
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="The knowledge base sharing has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the knowledge base sharing failed: {str(e)}")
raise
@router.get("/{knowledgeshare_id}", response_model=ApiResponse)
async def get_knowledgeshare(
knowledgeshare_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve knowledge base sharing information based on knowledgeshare_id
"""
api_logger.info(f"Obtain details of the knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
try:
# 1. Query knowledge base sharing information from the database
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
if not db_knowledgeshare:
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base sharing does not exist or access is denied"
)
api_logger.info(f"Knowledge base sharing query successful: (ID: {db_knowledgeshare.id})")
return success(data=knowledgeshare_schema.KnowledgeShare.model_validate(db_knowledgeshare), msg="Successfully obtained knowledge base sharing information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Knowledge base sharing query failed: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise
@router.delete("/{knowledgeshare_id}", response_model=ApiResponse)
async def delete_knowledgeshare(
knowledgeshare_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Delete knowledge base sharing
"""
api_logger.info(f"Delete knowledge base sharing request: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
try:
# 1. Query knowledge base sharing information from the database
api_logger.debug(f"Query knowledge base sharing: {knowledgeshare_id}")
db_knowledgeshare = knowledgeshare_service.get_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
if not db_knowledgeshare:
api_logger.warning(f"The knowledge base sharing does not exist or access is denied: knowledgeshare_id={knowledgeshare_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base sharing does not exist or access is denied"
)
# 2. Deleting shared knowledge base
knowledge_service.delete_knowledge_by_id(db, knowledge_id=db_knowledgeshare.target_kb_id ,current_user=current_user)
# 3. Delete knowledge base sharing
api_logger.debug(f"perform knowledge base sharing delete: (ID: {knowledgeshare_id})")
knowledgeshare_service.delete_knowledgeshare_by_id(db, knowledgeshare_id=knowledgeshare_id, current_user=current_user)
api_logger.info(f"The knowledge base sharing has been successfully deleted: (ID: {knowledgeshare_id})")
return success(msg="The knowledge base sharing has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise

View File

@@ -0,0 +1,802 @@
import json
import time
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.db import get_db
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.rag.llm.cv_model import QWenCV
from app.models import ModelApiKey, Knowledge
from app.services.memory_agent_service import MemoryAgentService
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services import task_service, workspace_service
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.dependencies import get_current_user
from app.models.user_model import User
from fastapi import APIRouter, Depends, File, UploadFile, Form
from app.repositories import knowledge_repository
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
import os
# 加载.env文件
load_dotenv()
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_agent_service = MemoryAgentService()
router = APIRouter(
prefix="/memory",
tags=["Memory"],
)
def validate_config_id(config_id: int, db: Session) -> int:
"""
Validate and ensure config_id is available, valid, and exists in database.
Args:
config_id: Configuration ID to validate
db: Database session for checking existence
Returns:
int: Validated config_id
Raises:
ValueError: If config_id is None, invalid, or doesn't exist in database
"""
if config_id is None:
api_logger.info(f"config_id is required but was not provided")
config_id = os.getenv('config_id')
if config_id is None:
raise ValueError("config_id is required but was not provided")
# Check if config exists in database
try:
from app.models.data_config_model import DataConfig
from app.models.models_model import ModelConfig
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config is None:
error_msg = f"Configuration with config_id={config_id} does not exist in database"
api_logger.error(error_msg)
raise ValueError(error_msg)
# Validate llm_id exists and is usable
if config.llm_id:
try:
llm_config = db.query(ModelConfig).filter(ModelConfig.id == config.llm_id).first()
if llm_config is None:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not llm_config.is_active:
error_msg = f"LLM model with id={config.llm_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"LLM validation successful: llm_id={config.llm_id}, name={llm_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating LLM model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no llm_id set")
raise ValueError(f"Config {config_id} has no llm_id set")
# Validate embedding_id exists and is usable
if config.embedding_id:
try:
embedding_config = db.query(ModelConfig).filter(ModelConfig.id == config.embedding_id).first()
if embedding_config is None:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) does not exist"
api_logger.error(error_msg)
raise ValueError(error_msg)
if not embedding_config.is_active:
error_msg = f"Embedding model with id={config.embedding_id} (from config_id={config_id}) is not active"
api_logger.error(error_msg)
raise ValueError(error_msg)
api_logger.debug(f"Embedding validation successful: embedding_id={config.embedding_id}, name={embedding_config.name}")
except ValueError:
raise
except Exception as e:
error_msg = f"Error validating embedding model: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
else:
api_logger.error(f"Config {config_id} has no embedding_id set")
raise ValueError(f"Config {config_id} has no embedding_id set")
api_logger.info(f"Config validation successful: config_id={config_id}, config_name={config.config_name}, llm_id={config.llm_id}, embedding_id={config.embedding_id}")
return config_id
except ValueError:
# Re-raise ValueError from above
raise
except Exception as e:
error_msg = f"Database error while validating config_id={config_id}: {str(e)}"
api_logger.error(error_msg, exc_info=True)
raise ValueError(error_msg)
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
Returns health status information from Redis cache
"""
api_logger.info("Health status check requested")
try:
result = await memory_agent_service.get_health_status()
return success(data=result["status"])
except Exception as e:
api_logger.error(f"Health status check failed: {str(e)}")
return fail(BizCode.SERVICE_UNAVAILABLE, "健康状态查询失败", str(e))
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
Args:
log_type: Log retrieval mode
- "file": Returns complete log file content in single response (default)
- "transmission": Real-time streaming of log content using Server-Sent Events
Returns:
- file mode: ApiResponse with log content
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
try:
log_content = memory_agent_service.get_log_content()
return success(data=log_content)
except ValueError as e:
api_logger.warning(f"Log content issue: {str(e)}")
return fail(BizCode.FILE_NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Log reading failed: {str(e)}")
return fail(BizCode.FILE_READ_ERROR, "日志读取失败", str(e))
else: # log_type == "transmission"
# Transmission mode: Stream log content using SSE
try:
api_logger.info("Starting SSE log streaming")
return StreamingResponse(
memory_agent_service.stream_log_content(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # Disable nginx buffering
}
)
except Exception as e:
api_logger.error(f"Failed to start log streaming: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and group_id
Returns:
Response with write operation status
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
config_id,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except Exception as e:
api_logger.error(f"Write operation error: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and group_id
Returns:
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
search_switch values:
- "0": Requires verification
- "1": No verification, direct split
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and group_id
Returns:
Response with query answer
"""
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.group_id,
user_input.message,
user_input.history,
user_input.search_switch,
config_id,
storage_type,
user_rag_memory_id
)
return success(data=result, msg="回复对话消息成功")
except Exception as e:
api_logger.error(f"Read operation error: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user)
):
"""
文件上传接口 - 支持图片识别
Args:
files: 上传的文件列表
metadata: 文件元数据(可选)
current_user: 当前用户
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
file_content = []
try:
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
model_name=apiConfig.model_name,
lang="Chinese",
base_url=apiConfig.api_base
)
description, token_count = vision_model.describe(content)
file_content.append(description)
api_logger.info(f"Image processed: {file.filename}, tokens: {token_count}")
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@router.post("/read_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server_async(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# Validate config_id
try:
config_id = validate_config_id(user_input.config_id, db)
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
return success(data={"task_id": task.id}, msg="查询任务已提交")
except Exception as e:
api_logger.error(f"Async read operation failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "回复对话消息失败", str(e))
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
Args:
task_id: Celery task ID returned from /read_service_async
Returns:
Task status and result if completed
Response format:
- PENDING: Task is waiting to be executed
- STARTED: Task has started
- SUCCESS: Task completed successfully, returns result
- FAILURE: Task failed, returns error message
"""
api_logger.info(f"Read task status check requested for task {task_id}")
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
if isinstance(task_result, dict):
# 新格式:包含详细信息
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
msg="查询任务已完成"
)
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
if isinstance(error_info, dict):
error_msg = error_info.get("error", str(error_info))
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
data={
"status": status,
"task_id": task_id,
"message": "任务处理中,请稍后查询"
},
msg="查询任务处理中"
)
else:
# 未知状态
return success(
data={
"status": status,
"task_id": task_id
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
Args:
task_id: Celery task ID returned from /writer_service_async
Returns:
Task status and result if completed
Response format:
- PENDING: Task is waiting to be executed
- STARTED: Task has started
- SUCCESS: Task completed successfully, returns result
- FAILURE: Task failed, returns error message
"""
api_logger.info(f"Write task status check requested for task {task_id}")
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
if isinstance(task_result, dict):
# 新格式:包含详细信息
return success(
data={
"result": task_result.get("result"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
msg="写入任务已完成"
)
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
if isinstance(error_info, dict):
error_msg = error_info.get("error", str(error_info))
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
data={
"status": status,
"task_id": task_id,
"message": "任务处理中,请稍后查询"
},
msg="写入任务处理中"
)
else:
# 未知状态
return success(
data={
"status": status,
"task_id": task_id
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and group_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
result = await memory_agent_service.classify_message_type(user_input.message)
return success(data=result)
except Exception as e:
api_logger.error(f"Message type classification failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "类型判断失败", str(e))
# ==================== 新增的三个接口路由 ====================
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=only_active,
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
):
"""
获取指定用户的热门记忆标签
返回格式:
[
{"name": "标签名", "frequency": 频次},
...
]
"""
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
limit=limit
)
return success(data=result, msg="获取热门记忆标签成功")
except Exception as e:
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签
返回格式:
{
"name": "用户名",
"tags": ["产品设计师", "旅行爱好者", "摄影发烧友"],
"hot_tags": [
{"name": "标签1", "frequency": 10},
{"name": "标签2", "frequency": 8},
...
]
}
"""
api_logger.info(f"User profile requested: end_user_id={end_user_id}, current_user={current_user.id}")
try:
result = await memory_agent_service.get_user_profile(
end_user_id=end_user_id,
current_user_id=str(current_user.id)
)
return success(data=result, msg="获取用户详情成功")
except Exception as e:
api_logger.error(f"User profile failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取用户详情失败", str(e))
# @router.get("/docs/api", response_model=ApiResponse)
# async def get_api_docs_api(
# file_path: Optional[str] = Query(None, description="API文档文件路径不传则使用默认路径")
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
# return fail(
# code=BizCode.BAD_REQUEST,
# msg=result["msg"],
# error=result.get("data", {}).get("error", result.get("error_code", ""))
# )
# except Exception as e:
# api_logger.error(f"API docs retrieval failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "API文档获取失败", str(e))

View File

@@ -0,0 +1,516 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import App as AppSchema
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/dashboard",
tags=["Dashboard"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/total_end_users", response_model=ApiResponse)
def get_workspace_total_end_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取用户列表的总用户数
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
total_end_users = memory_dashboard_service.get_workspace_total_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表
返回格式与原 memory_list 接口中的 end_users 字段相同
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
result = []
for end_user in end_users:
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
result.append(
{
'end_user':end_user,
'memory_num':memory_num
}
)
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
@router.get("/memory_increment", response_model=ApiResponse)
def get_workspace_memory_increment(
limit: int = Query(7, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间的记忆增量"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆增量")
memory_increment = memory_dashboard_service.get_workspace_memory_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user,
limit=limit
)
api_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
return success(data=memory_increment, msg="记忆增量获取成功")
@router.get("/api_increment", response_model=ApiResponse)
def get_workspace_api_increment(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取API调用趋势"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的API调用增量")
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
api_logger.info(f"成功获取 {api_increment} API调用增量")
return success(data=api_increment, msg="API调用增量获取成功")
@router.post("/total_memory", response_model=ApiResponse)
def write_workspace_total_memory(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""工作空间记忆总量的写入(异步任务)"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求写入工作空间 {workspace_id} 的记忆总量")
# 触发 Celery 异步任务
from app.celery_app import celery_app
task = celery_app.send_task(
"app.controllers.memory_storage_controller.search_all",
kwargs={"workspace_id": str(workspace_id)}
)
api_logger.info(f"已触发记忆总量统计任务task_id: {task.id}")
return success(
data={"task_id": task.id, "workspace_id": str(workspace_id)},
msg="记忆总量统计任务已启动"
)
@router.get("/task_status/{task_id}", response_model=ApiResponse)
def get_task_status(
task_id: str,
current_user: User = Depends(get_current_user),
):
"""查询异步任务的执行状态和结果"""
api_logger.info(f"用户 {current_user.username} 查询任务状态: task_id={task_id}")
from app.celery_app import celery_app
from celery.result import AsyncResult
# 获取任务结果
task_result = AsyncResult(task_id, app=celery_app)
response_data = {
"task_id": task_id,
"status": task_result.state, # PENDING, STARTED, SUCCESS, FAILURE, RETRY, REVOKED
}
# 如果任务完成,返回结果
if task_result.ready():
if task_result.successful():
response_data["result"] = task_result.result
api_logger.info(f"任务 {task_id} 执行成功")
return success(data=response_data, msg="任务执行成功")
else:
# 任务失败
response_data["error"] = str(task_result.result)
api_logger.error(f"任务 {task_id} 执行失败: {task_result.result}")
return success(data=response_data, msg="任务执行失败")
else:
# 任务还在执行中
api_logger.info(f"任务 {task_id} 状态: {task_result.state}")
return success(data=response_data, msg=f"任务状态: {task_result.state}")
@router.get("/memory_list", response_model=ApiResponse)
def get_workspace_memory_list(
limit: int = Query(7, description="记忆增量返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
用户记忆列表整合接口
整合以下三个接口的数据:
1. total_memory - 工作空间记忆总量
2. memory_increment - 工作空间记忆增量
3. hosts - 工作空间宿主列表
返回格式:
{
"total_memory": float,
"memory_increment": [
{"date": "2024-01-01", "count": 100},
...
],
"hosts": [
{"id": "uuid", "name": "宿主名", ...},
...
]
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆列表")
memory_list = memory_dashboard_service.get_workspace_memory_list(
db=db,
workspace_id=workspace_id,
current_user=current_user,
limit=limit
)
api_logger.info(f"成功获取记忆列表")
return success(data=memory_list, msg="记忆列表获取成功")
@router.get("/total_memory_count", response_model=ApiResponse)
async def get_workspace_total_memory_count(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的记忆总量通过聚合所有host的记忆数
逻辑:
1. 从 memory_list 获取所有 host_id
2. 对每个 host_id 调用 search_all 获取 total
3. 将所有 total 求和返回
返回格式:
{
"total_memory_count": int,
"host_count": int,
"details": [
{"host_id": "uuid", "count": 100},
...
]
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的记忆总量")
total_memory_count = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
workspace_id=workspace_id,
current_user=current_user,
end_user_id=end_user_id
)
api_logger.info(f"成功获取记忆总量: {total_memory_count.get('total_memory_count', 0)}")
return success(data=total_memory_count, msg="记忆总量获取成功")
# ======== RAG 数据统计 ========
@router.get("/total_rag_count", response_model=ApiResponse)
def get_workspace_total_rag_count(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取 rag 的总文档数、总chunk数、总知识库数量、总api调用数量
"""
total_documents = memory_dashboard_service.get_rag_total_doc(db, current_user)
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
data = {
'total_documents':total_documents,
'total_chunk':total_chunk,
'total_kb':total_kb,
'total_api':1024
}
return success(data=data, msg="RAG相关数据获取成功")
@router.get("/current_user_rag_total_num", response_model=ApiResponse)
def get_current_user_rag_total_num(
end_user_id: str = Query(..., description="宿主ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主的 RAG 的总chunk数量
"""
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主知识库中的chunk内容
"""
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功")
@router.get("/chunk_summary_tag", response_model=ApiResponse)
async def get_chunk_summary_tag(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
max_tags: int = Query(10, description="最大标签数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取chunk总结、提取的标签和人物形象
返回格式:
{
"summary": "chunk内容的总结",
"tags": [
{"tag": "标签1", "frequency": 5},
{"tag": "标签2", "frequency": 3},
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id,
limit=limit,
max_tags=max_tags,
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
@router.get("/chunk_insight", response_model=ApiResponse)
async def get_chunk_insight(
end_user_id: str = Query(..., description="宿主ID"),
limit: int = Query(15, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取chunk的洞察内容
返回格式:
{
"insight": "对chunk内容的深度洞察分析"
}
"""
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id,
limit=limit,
db=db,
current_user=current_user
)
api_logger.info(f"成功获取chunk洞察")
return success(data=data, msg="chunk洞察获取成功")
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
整合dashboard数据接口
整合以下接口的数据:
1. /dashboard/total_memory_count - 记忆总量
2. /dashboard/api_increment - API调用增量
3. /memory/stats/types - 知识库类型统计只要total数据
4. /dashboard/total_rag_count - RAG相关数据
根据 storage_type 判断调用不同的接口
返回格式:
{
"storage_type": str,
"neo4j_data": {
"total_memory": int,
"total_app": int,
"total_knowledge": int,
"total_api_call": int
} | null,
"rag_data": {
"total_memory": int,
"total_app": int,
"total_knowledge": int,
"total_api_call": int
} | null
}
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = None
# 根据 storage_type 决定返回哪个数据对象
# 如果是 'rag'neo4j_data 为 null否则 rag_data 为 null
result = {
"storage_type": storage_type,
"neo4j_data": None,
"rag_data": None
}
try:
# 如果 storage_type 为 'neo4j' 或空,获取 neo4j_data
if storage_type == 'neo4j':
neo4j_data = {
"total_memory": None,
"total_app": None,
"total_knowledge": None,
"total_api_call": None
}
# 1. 获取记忆总量total_memory
try:
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
db=db,
workspace_id=workspace_id,
current_user=current_user,
end_user_id=end_user_id
)
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
# total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
neo4j_data["total_app"] = len(apps_orm)
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
except Exception as e:
api_logger.warning(f"获取记忆总量失败: {str(e)}")
# 2. 获取知识库类型统计total_knowledge
try:
from app.services.memory_agent_service import MemoryAgentService
memory_agent_service = MemoryAgentService()
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
only_active=True,
current_workspace_id=workspace_id,
db=db
)
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用增量total_api_call转换为整数
try:
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取API调用增量失败: {str(e)}")
result["neo4j_data"] = neo4j_data
api_logger.info(f"成功获取neo4j_data")
# 如果 storage_type 为 'rag',获取 rag_data
elif storage_type == 'rag':
rag_data = {
"total_memory": None,
"total_app": None,
"total_knowledge": None,
"total_api_call": None
}
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunk总chunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
from app.repositories import app_repository
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
rag_data["total_app"] = len(apps_orm)
# total_knowledge: 使用 total_kb总知识库数
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
result["rag_data"] = rag_data
api_logger.info(f"成功获取rag_data")
api_logger.info(f"成功获取dashboard整合数据")
return success(data=result, msg="Dashboard数据获取成功")
except Exception as e:
api_logger.error(f"获取dashboard整合数据失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取dashboard整合数据失败: {str(e)}"
)

View File

@@ -0,0 +1,542 @@
from typing import Optional
import os
import uuid
from fastapi import APIRouter, Depends
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.services.memory_storage_service import (
MemoryStorageService,
DataConfigService,
kb_type_distribution,
search_dialogue,
search_chunk,
search_statement,
search_entity,
search_all,
search_detials,
search_edges,
search_entity_graph,
analytics_hot_memory_tags,
analytics_memory_insight_report,
analytics_recent_activity_stats,
analytics_user_summary,
)
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import (
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
ConfigPilotRun,
)
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.dependencies import get_current_user
from app.models.user_model import User
# Get API logger
api_logger = get_api_logger()
# Initialize service
memory_storage_service = MemoryStorageService()
router = APIRouter(
prefix="/memory-storage",
tags=["Memory Storage"],
)
@router.get("/info", response_model=ApiResponse)
async def get_storage_info(
storage_id: str,
current_user: User = Depends(get_current_user)
):
"""
Example wrapper endpoint - retrieves storage information
Args:
storage_id: Storage identifier
Returns:
Storage information
"""
api_logger.info(f"Storage info requested ")
try:
result = await memory_storage_service.get_storage_info()
return success(data=result)
except Exception as e:
api_logger.error(f"Storage info retrieval failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
# --- DB connection dependency ---
_CONN: Optional[object] = None
"""PostgreSQL 连接生成与管理(使用 psycopg2"""
# 这个可以转移,可能是已经有的
# PostgreSQL 数据库连接
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
host = os.getenv("DB_HOST")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host or "localhost",
port=port,
user=user,
password=password,
dbname=database,
)
# 设置自动提交,避免显式事务管理
conn.autocommit = True
# 设置会话时区为中国标准时间Asia/Shanghai便于直接以本地时区展示
try:
cur = conn.cursor()
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
cur.close()
except Exception:
# 时区设置失败不影响连接,仅记录但不抛出
pass
return conn
except Exception as e:
try:
print(f"[PostgreSQL] 连接失败: {e}")
except Exception:
pass
return None
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
global _CONN
if _CONN is None:
_CONN = _make_pgsql_conn()
return _CONN
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
"""Close and recreate the global DB connection."""
global _CONN
try:
if _CONN:
try:
_CONN.close()
except Exception:
pass
_CONN = _make_pgsql_conn()
return _CONN is not None
except Exception:
_CONN = None
return False
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
try:
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
payload.workspace_id = workspace_id
svc = DataConfigService(get_db_conn())
result = svc.create(payload)
return success(data=result, msg="创建成功")
except Exception as e:
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.delete(ConfigParamsDelete(config_id=config_id))
return success(data=result, msg="删除成功")
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
payload: ConfigUpdate,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted(
payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update_extracted(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新配置失败", str(e))
# --- Forget config params ---
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
def update_config_forget(
payload: ConfigUpdateForget,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.update_forget(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.get_extracted(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_forget(
config_id: str,
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
try:
svc = DataConfigService(get_db_conn())
result = svc.get_forget(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config(
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
try:
svc = DataConfigService(get_db_conn())
# 传递 workspace_id 进行过滤(保持为 UUID 类型)
result = svc.get_all(workspace_id=workspace_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read all config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
async def pilot_run(
payload: ConfigPilotRun,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
# 先尝试从数据库加载配置
try:
config_loaded = reload_configuration_from_database(str(payload.config_id))
if not config_loaded:
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
except Exception as e:
api_logger.error(f"Exception while loading configuration: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
try:
svc = DataConfigService(get_db_conn())
result = await svc.pilot_run(payload)
return success(data=result, msg="试运行完成")
except ValueError as e:
# 捕获参数验证错误
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
except Exception as e:
api_logger.error(f"Pilot run failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。
"""
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try:
result = await kb_type_distribution(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"KB type distribution failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
@router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try:
result = await search_dialogue(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search dialogue failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "对话查询失败", str(e))
@router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try:
result = await search_chunk(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search chunk failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "分块查询失败", str(e))
@router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try:
result = await search_statement(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search statement failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "语句查询失败", str(e))
@router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try:
result = await search_entity(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体查询失败", str(e))
@router.get("/search", response_model=ApiResponse)
async def search_all_num(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
result = await search_all(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search all failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "全部查询失败", str(e))
@router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try:
result = await search_detials(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search details failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "详情查询失败", str(e))
@router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try:
result = await search_edges(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api(
end_user_id: Optional[str] = None,
limit: int = 10,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
try:
result = await analytics_hot_memory_tags(end_user_id, limit)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}")
try:
result = await analytics_memory_insight_report(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Memory insight report failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info("Recent activity stats requested")
try:
result = await analytics_recent_activity_stats()
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
api_logger.info(f"User summary requested for end_user_id: {end_user_id}")
try:
result = await analytics_user_summary(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"User summary failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e))
from app.core.memory.utils.self_reflexion_utils import self_reflexion
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
Args:
None
Returns:
自我反思结果。
"""
return await self_reflexion(host_id)

View File

@@ -0,0 +1,332 @@
from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/models",
tags=["Models"],
)
@router.get("/type", response_model=ApiResponse)
def get_model_types():
return success(msg="获取模型类型成功", data=list(ModelType))
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
return success(msg="获取模型提供商成功", data=list(ModelProvider))
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")
try:
query = model_schema.ModelConfigQuery(
type=type,
provider=provider,
is_active=is_active,
is_public=is_public,
search=search,
page=page,
pagesize=pagesize
)
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
result_orm = ModelConfigService.get_model_list(db=db, query=query)
result = PageData.model_validate(result_orm)
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
return success(data=result, msg="模型配置列表获取成功")
except Exception as e:
api_logger.error(f"获取模型配置列表失败: {str(e)}")
raise
@router.get("/{model_id}", response_model=ApiResponse)
def get_model_by_id(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
"""
根据ID获取模型配置
"""
api_logger.info(f"获取模型配置请求: model_id={model_id}")
try:
api_logger.debug(f"开始获取模型配置: model_id={model_id}")
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
api_logger.info(f"模型配置获取成功: {result_orm.name}")
# 将ORM对象转换为Pydantic模型
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result_pydantic, msg="模型配置获取成功")
except Exception as e:
api_logger.error(f"获取模型配置失败: model_id={model_id} - {str(e)}")
raise
@router.post("", response_model=ApiResponse)
async def create_model(
model_data: model_schema.ModelConfigCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
创建模型配置
- 创建模型配置基础信息
- 如果包含 API Key会先验证配置有效性然后创建
- 验证失败时会抛出异常,不会创建配置
- 可通过 skip_validation=true 跳过验证
"""
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始创建模型配置: {model_data.name}")
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data)
api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})")
# 将ORM对象转换为Pydantic模型
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="模型配置创建成功")
except Exception as e:
api_logger.error(f"创建模型配置失败: {model_data.name} - {str(e)}")
raise
@router.put("/{model_id}", response_model=ApiResponse)
def update_model(
model_id: uuid.UUID,
model_data: model_schema.ModelConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})")
# 将ORM对象转换为Pydantic模型
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result_pydantic, msg="模型配置更新成功")
except Exception as e:
api_logger.error(f"更新模型配置失败: model_id={model_id} - {str(e)}")
raise
@router.delete("/{model_id}", response_model=ApiResponse)
def delete_model(
model_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除模型配置
"""
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始删除模型配置: model_id={model_id}")
ModelConfigService.delete_model(db=db, model_id=model_id)
api_logger.info(f"模型配置删除成功: model_id={model_id}")
return success(msg="模型配置删除成功")
except Exception as e:
api_logger.error(f"删除模型配置失败: model_id={model_id} - {str(e)}")
raise
# API Key 相关接口
@router.get("/{model_id}/apikeys", response_model=ApiResponse)
def get_model_api_keys(
model_id: uuid.UUID,
is_active: bool = Query(True, description="是否只获取活跃的API Key"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型的API Key列表
"""
api_logger.info(f"获取模型API Key列表请求: model_id={model_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始获取模型API Key列表: model_id={model_id}")
result_orm = ModelApiKeyService.get_api_keys_by_model(
db=db, model_config_id=model_id, is_active=is_active
)
# 将ORM对象列表转换为Pydantic模型列表
result_pydantic = [model_schema.ModelApiKey.model_validate(item) for item in result_orm]
api_logger.info(f"模型API Key列表获取成功: 数量={len(result_pydantic)}")
return success(data=result_pydantic, msg="模型API Key列表获取成功")
except Exception as e:
api_logger.error(f"获取模型API Key列表失败: model_id={model_id} - {str(e)}")
raise
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def create_model_api_key(
model_id: uuid.UUID,
api_key_data: model_schema.ModelApiKeyCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
为模型创建API Key
"""
api_logger.info(f"创建模型API Key请求: model_id={model_id}, model_name={api_key_data.model_name}, 用户: {current_user.username}")
try:
# 设置模型配置ID
api_key_data.model_config_id = model_id
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
return success(data=result, msg="模型API Key创建成功")
except Exception as e:
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
raise
@router.get("/apikeys/{api_key_id}", response_model=ApiResponse)
def get_api_key_by_id(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
根据ID获取API Key
"""
api_logger.info(f"获取API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始获取API Key: api_key_id={api_key_id}")
result = ModelApiKeyService.get_api_key_by_id(db=db, api_key_id=api_key_id)
api_logger.info(f"API Key获取成功: {result.model_name}")
return success(data=result, msg="API Key获取成功")
except Exception as e:
api_logger.error(f"获取API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.put("/apikeys/{api_key_id}", response_model=ApiResponse)
async def update_api_key(
api_key_id: uuid.UUID,
api_key_data: model_schema.ModelApiKeyUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
更新API Key
"""
api_logger.info(f"更新API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始更新API Key: api_key_id={api_key_id}")
result = await ModelApiKeyService.update_api_key(db=db, api_key_id=api_key_id, api_key_data=api_key_data)
api_logger.info(f"API Key更新成功: {result.model_name} (ID: {api_key_id})")
result_pydantic = model_schema.ModelApiKey.model_validate(result)
return success(data=result_pydantic, msg="API Key更新成功")
except Exception as e:
api_logger.error(f"更新API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.delete("/apikeys/{api_key_id}", response_model=ApiResponse)
def delete_api_key(
api_key_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除API Key
"""
api_logger.info(f"删除API Key请求: api_key_id={api_key_id}, 用户: {current_user.username}")
try:
api_logger.debug(f"开始删除API Key: api_key_id={api_key_id}")
ModelApiKeyService.delete_api_key(db=db, api_key_id=api_key_id)
api_logger.info(f"API Key删除成功: api_key_id={api_key_id}")
return success(msg="API Key删除成功")
except Exception as e:
api_logger.error(f"删除API Key失败: api_key_id={api_key_id} - {str(e)}")
raise
@router.post("/validate", response_model=ApiResponse)
async def validate_model_config(
validate_data: model_schema.ModelValidateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
验证模型配置是否有效
支持验证不同类型的模型:
- llm: 大语言模型
- chat: 对话模型
- embedding: 向量模型
- rerank: 重排序模型
"""
api_logger.info(f"验证模型配置请求: {validate_data.model_name} ({validate_data.model_type}), 用户: {current_user.username}")
result = await ModelConfigService.validate_model_config(
db=db,
model_name=validate_data.model_name,
provider=validate_data.provider,
api_key=validate_data.api_key,
api_base=validate_data.api_base,
model_type=validate_data.model_type,
test_message=validate_data.test_message
)
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")

View File

@@ -0,0 +1,404 @@
"""多 Agent 控制器"""
import uuid
from fastapi import APIRouter, Depends, Query, Path
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.schemas import multi_agent_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.multi_agent_service import MultiAgentService
from app.models import User
router = APIRouter(prefix="/apps", tags=["Multi-Agent"])
logger = get_business_logger()
# ==================== 多 Agent 配置管理 ====================
@router.post(
"/{app_id}/multi-agent",
summary="创建多 Agent 配置"
)
def create_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
data: multi_agent_schema.MultiAgentConfigCreate = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""创建多 Agent 配置
支持四种编排模式:
- sequential: 顺序执行
- parallel: 并行执行
- conditional: 条件路由
- loop: 循环执行
"""
service = MultiAgentService(db)
config = service.create_config(
app_id=app_id,
data=data,
created_by=current_user.id
)
return success(
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
msg="多 Agent 配置创建成功"
)
@router.get(
"/{app_id}/multi-agent",
summary="获取当前应用的最新有效多 Agent 配置"
)
def get_multi_agent_configs(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取指定应用的最新有效多 Agent 配置,如果不存在则返回默认模板"""
service = MultiAgentService(db)
# 通过 app_id 获取最新有效配置(已转换 agent_id 为 app_id
config = service.get_multi_agent_configs(app_id)
if not config:
# 返回默认模板
default_template = {
"app_id": str(app_id),
"master_agent_id": None,
"master_agent_name": None,
"orchestration_mode": "conditional",
"sub_agents": [],
"routing_rules": [],
"execution_config": {
"max_iterations": 10,
"timeout": 300,
"enable_parallel": False,
"error_handling": "stop"
},
"aggregation_strategy": "merge",
}
return success(
data=default_template,
msg="该应用暂无配置,返回默认模板"
)
# config 已经是字典格式,直接返回
return success(data=config)
@router.put(
"/{app_id}/multi-agent",
summary="更新多 Agent 配置"
)
def update_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
data: multi_agent_schema.MultiAgentConfigUpdate = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""更新多 Agent 配置"""
service = MultiAgentService(db)
config = service.update_config(app_id, data)
return success(
data=multi_agent_schema.MultiAgentConfigSchema.model_validate(config),
msg="多 Agent 配置更新成功"
)
@router.delete(
"/{app_id}/multi-agent",
summary="删除多 Agent 配置"
)
def delete_multi_agent_config(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除多 Agent 配置"""
service = MultiAgentService(db)
service.delete_config(app_id)
return success(msg="多 Agent 配置删除成功")
# ==================== 多 Agent 运行 ====================
@router.post(
"/{app_id}/multi-agent/run",
summary="运行多 Agent 任务"
)
async def run_multi_agent(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.MultiAgentRunRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""运行多 Agent 任务
根据配置的编排模式执行多个 Agent
- sequential: 按优先级顺序执行
- parallel: 并行执行所有 Agent
- conditional: 根据条件选择 Agent
- loop: 循环执行直到满足条件
"""
service = MultiAgentService(db)
result = await service.run(app_id, request)
return success(
data=multi_agent_schema.MultiAgentRunResponse(**result),
msg="多 Agent 任务执行成功"
)
# ==================== 智能路由测试 ====================
@router.post(
"/{app_id}/multi-agent/test-routing",
summary="测试智能路由"
)
async def test_routing(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.RoutingTestRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""测试智能路由功能
支持三种路由模式:
- keyword: 仅使用关键词路由
- llm: 使用 LLM 路由(需要提供 routing_model_id
- hybrid: 混合路由(关键词 + LLM
参数:
- message: 测试消息
- conversation_id: 会话 ID可选
- routing_model_id: 路由模型 ID可选用于 LLM 路由)
- use_llm: 是否启用 LLM默认 False
- keyword_threshold: 关键词置信度阈值(默认 0.8
"""
from app.services.conversation_state_manager import ConversationStateManager
from app.services.llm_router import LLMRouter
from app.models import ModelConfig
# 1. 获取多 Agent 配置
service = MultiAgentService(db)
config = service.get_config(app_id)
if not config:
return success(
data=None,
msg="应用未配置多 Agent无法测试路由"
)
# 2. 准备子 Agent 信息
sub_agents = {}
for sub_agent_info in config.sub_agents:
agent_id = sub_agent_info["agent_id"]
sub_agents[agent_id] = {
"name": sub_agent_info.get("name", agent_id),
"role": sub_agent_info.get("role", "")
}
# 3. 获取路由模型(如果指定)
routing_model = None
if request.routing_model_id:
routing_model = db.get(ModelConfig, request.routing_model_id)
if not routing_model:
return success(
data=None,
msg=f"路由模型不存在: {request.routing_model_id}"
)
# 4. 初始化路由器
state_manager = ConversationStateManager()
router = LLMRouter(
db=db,
state_manager=state_manager,
routing_rules=config.routing_rules or [],
sub_agents=sub_agents,
routing_model_config=routing_model,
use_llm=request.use_llm and routing_model is not None
)
# 5. 设置阈值
if request.keyword_threshold:
router.keyword_high_confidence_threshold = request.keyword_threshold
# 6. 执行路由
try:
routing_result = await router.route(
message=request.message,
conversation_id=str(request.conversation_id) if request.conversation_id else None,
force_new=request.force_new
)
# 7. 获取 Agent 信息
agent_id = routing_result["agent_id"]
agent_info = sub_agents.get(agent_id, {})
# 8. 构建响应
response_data = {
"message": request.message,
"routing_result": {
"agent_id": agent_id,
"agent_name": agent_info.get("name", agent_id),
"agent_role": agent_info.get("role", ""),
"confidence": routing_result["confidence"],
"strategy": routing_result["strategy"],
"topic": routing_result["topic"],
"topic_changed": routing_result["topic_changed"],
"reason": routing_result["reason"],
"routing_method": routing_result["routing_method"]
},
"cmulti-agent/batch-test-routingonfig_info": {
"use_llm": request.use_llm and routing_model is not None,
"routing_model": routing_model.name if routing_model else None,
"keyword_threshold": router.keyword_high_confidence_threshold,
"total_sub_agents": len(sub_agents)
}
}
return success(
data=response_data,
msg="路由测试成功"
)
except Exception as e:
logger.error(f"路由测试失败: {str(e)}")
return success(
data=None,
msg=f"路由测试失败: {str(e)}"
)
@router.post(
"/{app_id}/",
summary="批量测试智能路由"
)
async def batch_test_routing(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: multi_agent_schema.BatchRoutingTestRequest = ...,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""批量测试智能路由功能
用于测试多条消息的路由效果,并统计准确率
参数:
- test_cases: 测试用例列表
- routing_model_id: 路由模型 ID可选
- use_llm: 是否启用 LLM
- keyword_threshold: 关键词置信度阈值
"""
from app.services.conversation_state_manager import ConversationStateManager
from app.services.llm_router import LLMRouter
from app.models import ModelConfig
# 1. 获取多 Agent 配置
service = MultiAgentService(db)
config = service.get_config(app_id)
if not config:
return success(
data=None,
msg="应用未配置多 Agent无法测试路由"
)
# 2. 准备子 Agent 信息
sub_agents = {}
for sub_agent_info in config.sub_agents:
agent_id = sub_agent_info["agent_id"]
sub_agents[agent_id] = {
"name": sub_agent_info.get("name", agent_id),
"role": sub_agent_info.get("role", "")
}
# 3. 获取路由模型
routing_model = None
if request.routing_model_id:
routing_model = db.get(ModelConfig, request.routing_model_id)
# 4. 初始化路由器
state_manager = ConversationStateManager()
router = LLMRouter(
db=db,
state_manager=state_manager,
routing_rules=config.routing_rules or [],
sub_agents=sub_agents,
routing_model_config=routing_model,
use_llm=request.use_llm and routing_model is not None
)
if request.keyword_threshold:
router.keyword_high_confidence_threshold = request.keyword_threshold
# 5. 批量测试
results = []
correct_count = 0
total_count = len(request.test_cases)
for test_case in request.test_cases:
try:
routing_result = await router.route(
message=test_case.message,
conversation_id=str(uuid.uuid4()) # 每个测试用例使用独立会话
)
agent_id = routing_result["agent_id"]
agent_info = sub_agents.get(agent_id, {})
# 判断是否正确
is_correct = None
if test_case.expected_agent_id:
is_correct = (agent_id == str(test_case.expected_agent_id))
if is_correct:
correct_count += 1
results.append({
"message": test_case.message,
"description": test_case.description,
"routed_agent_id": agent_id,
"routed_agent_name": agent_info.get("name"),
"expected_agent_id": str(test_case.expected_agent_id) if test_case.expected_agent_id else None,
"is_correct": is_correct,
"confidence": routing_result["confidence"],
"routing_method": routing_result["routing_method"],
"strategy": routing_result["strategy"]
})
except Exception as e:
logger.error(f"测试用例失败: {test_case.message}, 错误: {str(e)}")
results.append({
"message": test_case.message,
"description": test_case.description,
"error": str(e)
})
# 6. 统计
accuracy = None
if correct_count > 0:
total_with_expected = sum(1 for r in results if r.get("expected_agent_id"))
if total_with_expected > 0:
accuracy = correct_count / total_with_expected * 100
response_data = {
"total_count": total_count,
"correct_count": correct_count,
"accuracy": accuracy,
"results": results,
"config_info": {
"use_llm": request.use_llm and routing_model is not None,
"routing_model": routing_model.name if routing_model else None,
"keyword_threshold": router.keyword_high_confidence_threshold
}
}
return success(
data=response_data,
msg=f"批量测试完成,准确率: {accuracy:.1f}%" if accuracy else "批量测试完成"
)

View File

@@ -0,0 +1,437 @@
from fastapi import APIRouter, Depends, Query, Request, Header
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
import uuid
import hashlib
import time
import jwt
from typing import Optional, Dict
from functools import wraps
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.config import settings
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.conversation_service import ConversationService
from app.services.auth_service import create_access_token
from app.dependencies import get_share_user_id, ShareTokenData
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
def get_base_url(request: Request) -> str:
"""从请求中获取基础 URL"""
return f"{request.url.scheme}://{request.url.netloc}"
def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
"""获取或生成用户 ID
优先级:
1. 使用前端传递的 user_id
2. 基于 IP + User-Agent 生成唯一 ID
Args:
payload_user_id: 前端传递的 user_id
request: FastAPI Request 对象
Returns:
用户 ID
"""
if payload_user_id:
return payload_user_id
# 获取客户端 IP
client_ip = request.client.host if request.client else "unknown"
# 获取 User-Agent
user_agent = request.headers.get("user-agent", "unknown")
# 生成唯一 ID基于 IP + User-Agent 的哈希
unique_string = f"{client_ip}_{user_agent}"
hash_value = hashlib.md5(unique_string.encode()).hexdigest()[:16]
return f"guest_{hash_value}"
@router.post(
"/{share_token}/token",
summary="获取访问 token"
)
def get_access_token(
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
):
"""获取访问 token
- 用户通过 user_id + share_token 换取访问 token
- 后续请求需要携带此 token
"""
# 获取或生成 user_id
user_id = get_or_generate_user_id(payload.user_id, request)
# 验证分享链接(可选:验证密码)
service = ReleaseShareService(db)
try:
service.get_shared_release_info(
share_token=share_token,
password=payload.password
)
except Exception as e:
logger.error(f"获取分享信息失败: {str(e)}")
raise
# 生成 token
access_token = create_access_token(user_id, share_token)
logger.info(
f"生成访问 token",
extra={
"share_token": share_token,
"user_id": user_id
}
)
return success(data={
"access_token": access_token,
"token_type": "Bearer",
"user_id": user_id
})
@router.get(
"",
summary="获取公开分享的应用信息",
response_model=None
)
def get_shared_release(
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取公开分享的发布版本信息
- 无需认证即可访问
- 如果设置了密码保护,需要提供正确的密码
- 如果密码错误或未提供密码,返回基本信息(不含配置详情)
"""
service = ReleaseShareService(db)
info = service.get_shared_release_info(
share_token=share_data.share_token,
password=password
)
return success(data=info)
@router.post(
"/verify",
summary="验证访问密码"
)
def verify_password(
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""验证分享的访问密码
- 用于前端先验证密码,再获取完整信息
"""
service = ReleaseShareService(db)
is_valid = service.verify_password(
share_token=share_data.share_token,
password=payload.password
)
return success(data={"valid": is_valid})
@router.get(
"/embed",
summary="获取嵌入代码"
)
def get_embed_code(
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取嵌入代码
- 返回 iframe 嵌入代码
- 可以自定义宽度和高度
"""
base_url = get_base_url(request) if request else None
service = ReleaseShareService(db)
embed_code = service.get_embed_code(
share_token=share_data.share_token,
width=width,
height=height,
base_url=base_url
)
return success(data=embed_code)
# ---------- 会话管理接口 ----------
@router.get(
"/conversations",
summary="获取会话列表"
)
def list_conversations(
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取分享应用的会话列表
- 可以按 user_id 筛选
- 支持分页
"""
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
password=password,
page=page,
pagesize=pagesize
)
items = [conversation_schema.Conversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@router.get(
"/conversations/{conversation_id}",
summary="获取会话详情(含消息)"
)
def get_conversation(
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取会话详情和消息历史"""
chat_service = SharedChatService(db)
conversation = chat_service.get_conversation_messages(
share_token=share_data.share_token,
conversation_id=conversation_id,
password=password
)
# 获取消息
conv_service = ConversationService(db)
messages = conv_service.get_messages(conversation_id)
# 构建响应
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
conv_dict["messages"] = [
conversation_schema.Message.model_validate(m) for m in messages
]
return success(data=conv_dict)
# ---------- 聊天接口 ----------
@router.post(
"/chat",
summary="发送消息(支持流式和非流式)"
)
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""发送消息并获取回复
使用 Bearer token 认证:
- Header: Authorization: Bearer {token}
- user_id 和 share_token 从 token 中解码
- 支持多轮对话(提供 conversation_id
- 支持流式返回(设置 stream=true
- 如果不提供 conversation_id会自动创建新会话
"""
service = SharedChatService(db)
# 从依赖中获取 user_id 和 share_token
user_id = share_data.user_id
share_token = share_data.share_token
password = None # Token 认证不需要密码
# end_user_id = user_id
other_id = user_id
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service._get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id,
original_user_id=user_id # Save original user_id to other_id
)
# 获取应用类型
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == "agent":
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == "multi_agent":
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
else:
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
# 获取或创建会话(提前验证)
conversation = service.create_or_get_conversation(
share_token=share_data.share_token,
conversation_id=payload.conversation_id,
user_id=str(new_end_user.id), # 转换为字符串
password=password
)
logger.debug(
f"参数验证完成",
extra={
"share_token": share_token,
"app_type": app_type,
"conversation_id": str(conversation.id),
"stream": payload.stream
}
)
except Exception as e:
# 验证失败,直接抛出异常(会被 FastAPI 的异常处理器捕获)
logger.error(f"参数验证失败: {str(e)}")
raise
if app_type == AppType.AGENT:
# 流式返回
if payload.stream:
async def event_generator():
async for event in service.chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
result = await service.chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
)
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
if payload.stream:
async def event_generator():
async for event in service.multi_agent_chat_stream(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 多 Agent 非流式返回
result = await service.multi_agent_chat(
share_token=share_token,
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
)
return success(data=conversation_schema.ChatResponse(**result))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -0,0 +1,170 @@
import uuid
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.schemas import release_share_schema
from app.services.release_share_service import ReleaseShareService
from app.dependencies import get_current_user, cur_workspace_access_guard
router = APIRouter(tags=["Release Share"])
logger = get_business_logger()
def get_base_url(request: Request) -> str:
"""从请求中获取基础 URL"""
return f"{request.url.scheme}://{request.url.netloc}"
@router.post(
"/apps/{app_id}/releases/{release_id}/share",
summary="创建/启用分享配置"
)
@cur_workspace_access_guard()
def create_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
payload: release_share_schema.ReleaseShareCreate,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""创建或更新发布版本的分享配置
- 如果已存在分享配置,则更新
- 自动生成唯一的分享 token
- 返回完整的分享 URL
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.create_or_update_share(
release_id=release_id,
user_id=current_user.id,
workspace_id=workspace_id,
data=payload,
base_url=base_url
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享配置已创建")
@router.put(
"/apps/{app_id}/releases/{release_id}/share",
summary="更新分享配置"
)
@cur_workspace_access_guard()
def update_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
payload: release_share_schema.ReleaseShareUpdate,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""更新分享配置
- 可以更新启用状态、密码、嵌入设置等
- 不会改变 share_token
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.update_share(
release_id=release_id,
workspace_id=workspace_id,
data=payload
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享配置已更新")
@router.get(
"/apps/{app_id}/releases/{release_id}/share",
summary="获取分享配置"
)
@cur_workspace_access_guard()
def get_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取发布版本的分享配置
- 如果不存在分享配置,返回 null
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.get_share(
release_id=release_id,
workspace_id=workspace_id,
base_url=base_url
)
return success(data=share)
@router.delete(
"/apps/{app_id}/releases/{release_id}/share",
summary="删除分享配置"
)
@cur_workspace_access_guard()
def delete_share(
app_id: uuid.UUID,
release_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""删除分享配置
- 删除后,公开访问链接将失效
"""
workspace_id = current_user.current_workspace_id
service = ReleaseShareService(db)
service.delete_share(
release_id=release_id,
workspace_id=workspace_id
)
return success(msg="分享配置已删除")
@router.post(
"/apps/{app_id}/releases/{release_id}/share/regenerate-token",
summary="重新生成分享链接"
)
@cur_workspace_access_guard()
def regenerate_token(
app_id: uuid.UUID,
release_id: uuid.UUID,
request: Request,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""重新生成分享 token
- 旧的分享链接将失效
- 生成新的唯一 token
"""
workspace_id = current_user.current_workspace_id
base_url = get_base_url(request)
service = ReleaseShareService(db)
share = service.regenerate_token(
release_id=release_id,
workspace_id=workspace_id
)
share_schema = service._convert_to_schema(share, base_url)
return success(data=share_schema, msg="分享链接已重新生成")

View File

@@ -0,0 +1,17 @@
"""Service API Controllers - 基于 API Key 认证的服务接口
路由前缀: /v1
认证方式: API Key
"""
from fastapi import APIRouter
from . import app_api_controller, rag_api_controller, memory_api_controller
# 创建 V1 API 路由器
service_router = APIRouter()
# 注册子路由
service_router.include_router(app_api_controller.router)
service_router.include_router(rag_api_controller.router)
service_router.include_router(memory_api_controller.router)
__all__ = ["service_router"]

View File

@@ -0,0 +1,16 @@
"""App 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"])
logger = get_business_logger()
@router.get("")
async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")

View File

@@ -0,0 +1,16 @@
"""Memory 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
logger = get_business_logger()
@router.get("")
async def get_memory_info():
"""获取记忆服务信息(占位)"""
return success(data={}, msg="Memory API - Coming Soon")

View File

@@ -0,0 +1,16 @@
"""RAG 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
router = APIRouter(prefix="/knowledge", tags=["V1 - RAG API"])
logger = get_business_logger()
@router.get("")
async def list_knowledge():
"""列出可访问的知识库(占位)"""
return success(data=[], msg="RAG API - Coming Soon")

View File

@@ -0,0 +1,23 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.db import get_db
from app.schemas.response_schema import ApiResponse
from app.services import user_service
router = APIRouter(
prefix="/setup",
tags=["Setup"],
)
@router.post("", summary="Create the first superuser", response_model=ApiResponse)
def setup_initial_user(db: Session = Depends(get_db)):
"""
Create the initial superuser. This can only be run once.
Reads credentials from environment variables.
"""
user = user_service.create_initial_superuser(db)
if not user:
return success(msg="Superuser already exists.")
return success(msg="Superuser created successfully.")

View File

@@ -0,0 +1,25 @@
from fastapi import APIRouter, status
from app.schemas.item_schema import Item
from app.services import task_service
router = APIRouter(
prefix="/tasks",
tags=["Tasks"],
)
@router.post("/process_item", status_code=status.HTTP_202_ACCEPTED)
def process_item_task(item: Item):
"""
This endpoint receives an item, and instead of processing it directly,
it sends a task to the Celery queue via the task service.
"""
task_id = task_service.create_processing_task(item.dict())
return {"message": "Task accepted. The item is being processed in the background.", "task_id": task_id}
@router.get("/result/{task_id}")
def get_task_result_controller(task_id: str):
"""
This endpoint allows clients to check the status and result of a
previously submitted task using its ID, by calling the task service.
"""
return task_service.get_task_result(task_id)

View File

@@ -0,0 +1,126 @@
from fastapi import APIRouter, Depends, status, Query, HTTPException
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/test",
tags=["test"],
)
@router.get(f"/llm/{{model_id}}", response_model=ApiResponse)
def test_llm(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
try:
apiConfig: ModelApiKey = config.api_keys[0]
llm = RedBearLLM(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
), type=config.type)
print(llm.dict())
template = """Question: {question}
Answer: Let's think step by step."""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
answer = chain.invoke({"question": "What is LangChain?"})
print("Answer:", answer)
return success(msg="测试LLM成功", data={"question": "What is LangChain?", "answer": answer})
except Exception as e:
api_logger.error(f"测试LLM失败: {str(e)}")
raise
@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse)
def test_embedding(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model = RedBearEmbeddings(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
))
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
embeddings = model.embed_documents(data)
print(embeddings)
query = "我想找一个适合学习的地方。"
query_embedding = model.embed_query(query)
print(query_embedding)
return success(msg="测试LLM成功")
@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse)
def test_rerank(
model_id: uuid.UUID,
db: Session = Depends(get_db)
):
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
api_logger.error(f"模型ID {model_id} 不存在")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
apiConfig: ModelApiKey = config.api_keys[0]
model = RedBearRerank(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
))
query = "最近哪家咖啡店评价最好?"
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
scores = model.rerank(query=query, documents=data, top_n=3)
print(scores)
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})

View File

@@ -0,0 +1,376 @@
"""
Upload Controller for Generic File Upload System
Handles HTTP requests for file upload, download, deletion, and metadata updates.
"""
import os
import json
from typing import List, Optional, Any
from pathlib import Path
from fastapi import APIRouter, Depends, File, UploadFile, Form
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.generic_file_schema import (
GenericFileResponse,
FileMetadataUpdate,
UploadResultSchema,
BatchUploadResponse
)
from app.core.response_utils import success, fail
from app.core.upload_enums import UploadContext
from app.services.upload_service import UploadService
from app.core.logging_config import get_logger
from app.core.exceptions import (
ValidationException,
ResourceNotFoundException,
FileUploadException,
BusinessException
)
# Get logger
logger = get_logger(__name__)
# Create router
router = APIRouter(
prefix="/api",
tags=["upload"],
dependencies=[Depends(get_current_user)]
)
# Initialize upload service
upload_service = UploadService()
@router.post("/upload", response_model=ApiResponse)
async def upload_file(
file: UploadFile = File(..., description="要上传的文件"),
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
单文件上传接口
- **file**: 要上传的文件
- **context**: 上传上下文,决定文件存储位置和验证规则
- **metadata**: 可选的文件元数据JSON格式字符串
返回上传成功的文件信息
"""
logger.info(f"Upload request: filename={file.filename}, context={context}, user={current_user.id}")
try:
# Validate and parse context
try:
upload_context = UploadContext(context)
except ValueError:
logger.warning(f"Invalid upload context: {context}")
raise ValidationException(
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
field="context"
)
# Parse metadata if provided
file_metadata = {}
if metadata:
try:
file_metadata = json.loads(metadata)
except json.JSONDecodeError:
logger.warning(f"Invalid metadata JSON: {metadata}")
raise ValidationException(
"元数据必须是有效的JSON格式",
field="metadata"
)
# Upload file
db_file = upload_service.upload_file(
file=file,
context=upload_context,
metadata=file_metadata,
current_user=current_user,
db=db
)
# Convert to response schema
file_response = GenericFileResponse.model_validate(db_file)
logger.info(f"Upload successful: {file.filename} (ID: {db_file.id})")
return success(data=file_response.dict(), msg="文件上传成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Upload failed: {str(e)}")
# Wrap unknown exceptions as FileUploadException
raise FileUploadException(
f"文件上传失败: {str(e)}",
cause=e
)
@router.post("/upload/batch", response_model=ApiResponse)
async def upload_files_batch(
files: List[UploadFile] = File(..., description="要上传的文件列表"),
context: str = Form(..., description="上传上下文 (avatar, app_icon, knowledge_base, temp, attachment)"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
批量文件上传接口
- **files**: 要上传的文件列表最多20个
- **context**: 上传上下文,决定文件存储位置和验证规则
- **metadata**: 可选的文件元数据JSON格式字符串应用于所有文件
返回每个文件的上传结果
"""
logger.info(f"Batch upload request: {len(files)} files, context={context}, user={current_user.id}")
try:
# Validate and parse context
try:
upload_context = UploadContext(context)
except ValueError:
logger.warning(f"Invalid upload context: {context}")
raise ValidationException(
f"无效的上传上下文: {context}. 允许的值: {', '.join([c.value for c in UploadContext])}",
field="context"
)
# Parse metadata if provided
file_metadata = {}
if metadata:
try:
file_metadata = json.loads(metadata)
except json.JSONDecodeError:
logger.warning(f"Invalid metadata JSON: {metadata}")
raise ValidationException(
"元数据必须是有效的JSON格式",
field="metadata"
)
# Upload files in batch
upload_results = upload_service.upload_files_batch(
files=files,
context=upload_context,
metadata=file_metadata,
current_user=current_user,
db=db
)
# Convert results to response schemas
result_schemas = []
for result in upload_results:
result_schema = UploadResultSchema(
success=result.success,
file_id=result.file_id,
file_name=result.file_name,
error=result.error,
file_info=None
)
# If upload was successful, get file info
if result.success and result.file_id:
try:
db_file = upload_service.get_file(result.file_id, current_user, db)
result_schema.file_info = GenericFileResponse.model_validate(db_file)
except Exception as e:
logger.warning(f"Failed to get file info for {result.file_id}: {str(e)}")
result_schemas.append(result_schema)
# Create batch response
batch_response = BatchUploadResponse(
total=len(files),
success_count=sum(1 for r in upload_results if r.success),
failed_count=sum(1 for r in upload_results if not r.success),
results=result_schemas
)
logger.info(f"Batch upload completed: {batch_response.success_count}/{batch_response.total} successful")
return success(data=batch_response.dict(), msg="批量上传完成")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Batch upload failed: {str(e)}")
# Wrap unknown exceptions as FileUploadException
raise FileUploadException(
f"批量上传失败: {str(e)}",
cause=e
)
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
file_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Any:
"""
文件下载接口
- **file_id**: 文件ID
返回文件内容供下载
"""
logger.info(f"Download request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Get file from database
db_file = upload_service.get_file(file_uuid, current_user, db)
# Check if physical file exists
storage_path = Path(db_file.storage_path)
if not storage_path.exists():
logger.error(f"Physical file not found: {storage_path}")
raise ResourceNotFoundException(
"文件",
str(file_uuid),
context={"detail": "文件未找到(可能已被删除)"}
)
# Return file response
logger.info(f"Download successful: {db_file.file_name} (ID: {file_id})")
return FileResponse(
path=str(storage_path),
filename=db_file.file_name,
media_type=db_file.mime_type or "application/octet-stream"
)
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Download failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件下载失败: {str(e)}",
cause=e
)
@router.delete("/files/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
文件删除接口
- **file_id**: 文件ID
删除文件(包括物理文件和数据库记录)
"""
logger.info(f"Delete request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Delete file
upload_service.delete_file(file_uuid, current_user, db)
logger.info(f"Delete successful: file_id={file_id}")
return success(msg="文件删除成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Delete failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件删除失败: {str(e)}",
cause=e
)
@router.put("/files/{file_id}", response_model=ApiResponse)
async def update_file_metadata(
file_id: str,
update_data: FileMetadataUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
文件元数据更新接口
- **file_id**: 文件ID
- **update_data**: 要更新的元数据
更新文件的元数据(文件名、自定义元数据、公开状态)
"""
logger.info(f"Update metadata request: file_id={file_id}, user={current_user.id}")
try:
# Parse file_id
import uuid
try:
file_uuid = uuid.UUID(file_id)
except ValueError:
logger.warning(f"Invalid file ID format: {file_id}")
raise ValidationException(
"无效的文件ID格式",
field="file_id"
)
# Convert update data to dict, excluding unset fields
update_dict = update_data.dict(exclude_unset=True)
if not update_dict:
logger.warning(f"No fields to update for file: {file_id}")
raise ValidationException(
"没有提供要更新的字段",
field="update_data"
)
# Update file metadata
updated_file = upload_service.update_file_metadata(
file_uuid, update_dict, current_user, db
)
# Convert to response schema
file_response = GenericFileResponse.model_validate(updated_file)
logger.info(f"Update metadata successful: file_id={file_id}")
return success(data=file_response.dict(), msg="文件元数据更新成功")
except BusinessException:
# Business exceptions are handled by global exception handlers
raise
except Exception as e:
logger.error(f"Update metadata failed: {str(e)}")
# Wrap unknown exceptions
raise FileUploadException(
f"文件元数据更新失败: {str(e)}",
cause=e
)

View File

@@ -0,0 +1,183 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
import uuid
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.models.user_model import User
from app.schemas import user_schema
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
from app.schemas.response_schema import ApiResponse
from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
# 获取API专用日志器
api_logger = get_api_logger()
router = APIRouter(
prefix="/users",
tags=["Users"],
)
@router.post("/superuser", response_model=ApiResponse)
def create_superuser(
user: user_schema.UserCreate,
db: Session = Depends(get_db),
current_superuser: User = Depends(get_current_superuser)
):
"""创建超级管理员(仅超级管理员可访问)"""
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
result = user_service.create_superuser(db=db, user=user, current_user=current_superuser)
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="超级管理员创建成功")
@router.delete("/{user_id}", response_model=ApiResponse)
def delete_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""停用用户(软删除)"""
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.deactivate_user(
db=db, user_id_to_deactivate=user_id, current_user=current_user
)
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
return success(msg="用户停用成功")
@router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""激活用户"""
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.activate_user(
db=db, user_id_to_activate=user_id, current_user=current_user
)
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户激活成功")
@router.get("", response_model=ApiResponse)
def get_current_user_info(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前用户信息"""
api_logger.info(f"当前用户信息请求: {current_user.username}")
result = user_service.get_user(
db=db, user_id=current_user.id, current_user=current_user
)
result_schema = user_schema.User.model_validate(result)
# 设置当前工作空间的角色和名称
if current_user.current_workspace_id:
from app.repositories.workspace_repository import WorkspaceRepository
workspace_repo = WorkspaceRepository(db)
current_workspace = workspace_repo.get_workspace_by_id(current_user.current_workspace_id)
if current_workspace:
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
if ws.workspace_id == current_user.current_workspace_id:
result_schema.role = ws.role
break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
return success(data=result_schema, msg="用户信息获取成功")
@router.get("/superusers", response_model=ApiResponse)
def get_tenant_superusers(
include_inactive: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
superusers = user_service.get_tenant_superusers(
db=db,
current_user=current_user,
include_inactive=include_inactive
)
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
return success(data=superusers_schema, msg="租户超管列表获取成功")
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""根据用户ID获取用户信息"""
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
result = user_service.get_user(
db=db, user_id=user_id, current_user=current_user
)
api_logger.info(f"用户信息获取成功: {result.username}")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg="用户信息获取成功")
@router.put("/change-password", response_model=ApiResponse)
async def change_password(
request: ChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""修改当前用户密码"""
api_logger.info(f"用户密码修改请求: {current_user.username}")
await user_service.change_password(
db=db,
user_id=current_user.id,
old_password=request.old_password,
new_password=request.new_password,
current_user=current_user
)
api_logger.info(f"用户密码修改成功: {current_user.username}")
return success(msg="密码修改成功")
@router.put("/admin/change-password", response_model=ApiResponse)
async def admin_change_password(
request: AdminChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""超级管理员修改指定用户的密码"""
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
user, generated_password = await user_service.admin_change_password(
db=db,
target_user_id=request.user_id,
new_password=request.new_password,
current_user=current_user
)
# 根据是否生成了随机密码来构造响应
if request.new_password:
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
return success(msg="密码修改成功")
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg="密码重置成功")

View File

@@ -0,0 +1,342 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from app.models.user_model import User
from app.models.tenant_model import Tenants
from app.models.workspace_model import Workspace, InviteStatus
from app.schemas.response_schema import ApiResponse
from app.schemas.workspace_schema import (
WorkspaceCreate, WorkspaceUpdate, WorkspaceResponse,
WorkspaceInviteCreate, WorkspaceInviteResponse,
InviteValidateResponse, InviteAcceptRequest,
WorkspaceMemberUpdate, WorkspaceMemberItem
)
from app.schemas import knowledge_schema
from app.services import workspace_service
from app.core.logging_config import get_api_logger
from app.services import knowledge_service, document_service
# 获取API专用日志器
api_logger = get_api_logger()
# 需要认证的路由器
router = APIRouter(
prefix="/workspaces",
tags=["Workspaces"],
dependencies=[Depends(get_current_user)]
)
# 公开路由器(不需要认证)
public_router = APIRouter(
prefix="/workspaces",
tags=["Workspaces"]
)
def _convert_members_to_table_items(members):
"""将工作空间成员列表转换为表格项"""
return [
WorkspaceMemberItem(
id=m.id,
username=m.user.username,
account=m.user.email,
role=m.role,
last_login_at=m.user.last_login_at
)
for m in members
]
@router.get("", response_model=ApiResponse)
def get_workspaces(
include_current: bool = Query(True, description="是否包含当前工作空间"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
current_tenant: Tenants = Depends(get_current_tenant)
):
"""获取当前租户下用户参与的所有工作空间
Args:
include_current: 是否包含当前工作空间(默认 True
"""
api_logger.info(
f"用户 {current_user.username} 在租户 {current_tenant.name} 中请求获取工作空间列表",
extra={"include_current": include_current}
)
workspaces = workspace_service.get_user_workspaces(db, current_user)
# 如果不包含当前工作空间,则过滤掉
if not include_current and current_user.current_workspace_id:
workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id]
api_logger.debug(
f"过滤掉当前工作空间",
extra={"current_workspace_id": str(current_user.current_workspace_id)}
)
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
return success(data=workspaces_schema, msg="工作空间列表获取成功")
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""创建新的工作空间"""
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user)
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")
@router.put("", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace(
workspace: WorkspaceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""更新工作空间"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 ID: {workspace_id}")
result = workspace_service.update_workspace(
db=db,
workspace_id=workspace_id,
workspace_in=workspace,
user=current_user,
)
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间更新成功")
@router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_cur_workspace_members(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间成员列表(关系序列化)"""
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
members = workspace_service.get_workspace_members(
db=db,
workspace_id=current_user.current_workspace_id,
user=current_user,
)
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
table_items = _convert_members_to_table_items(members)
return success(data=table_items, msg="工作空间成员列表获取成功")
@router.put("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def update_workspace_members(
updates: List[WorkspaceMemberUpdate],
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
members = workspace_service.update_workspace_member_roles(
db=db,
workspace_id=workspace_id,
updates=updates,
user=current_user,
)
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
return success(msg="成员角色更新成功")
@router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
workspace_service.delete_workspace_member(
db=db,
workspace_id=workspace_id,
member_id=member_id,
user=current_user,
)
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
return success(msg="成员删除成功")
# 创建空间协作邀请
@router.post("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def create_workspace_invite(
invite_data: WorkspaceInviteCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""创建工作空间邀请"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求为工作空间 {workspace_id} 创建邀请: {invite_data.email}")
result = workspace_service.create_workspace_invite(
db=db,
workspace_id=workspace_id,
invite_data=invite_data,
user=current_user
)
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
return success(data=result, msg="邀请创建成功")
@router.get("/invites", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_invites(
status_filter: Optional[InviteStatus] = Query(None, alias="status"),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作空间邀请列表"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的邀请列表")
invites = workspace_service.get_workspace_invites(
db=db,
workspace_id=workspace_id,
user=current_user,
status=status_filter,
limit=limit,
offset=offset
)
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
return success(data=invites, msg="邀请列表获取成功")
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
token: str,
db: Session = Depends(get_db),
):
"""获取工作空间邀请用户信息(无需认证)"""
result = workspace_service.validate_invite_token(db=db, token=token)
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
return success(data=result, msg="邀请验证成功")
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
def revoke_workspace_invite(
invite_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""撤销工作空间邀请"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求撤销工作空间 {workspace_id} 的邀请 {invite_id}")
result = workspace_service.revoke_workspace_invite(
db=db,
workspace_id=workspace_id,
invite_id=invite_id,
user=current_user
)
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
return success(data=result, msg="邀请撤销成功")
# ==================== 公开邀请接口(无需认证) ====================
# # 创建一个新的路由器用于公开接口
# public_router = APIRouter(
# prefix="/invites",
# tags=["Public Invites"]
# )
# @public_router.get("/validate", response_model=ApiResponse)
# def validate_invite_token(
# token: str = Query(..., description="邀请令牌"),
# db: Session = Depends(get_db),
# ):
# """验证邀请令牌(公开接口)"""
# api_logger.info(f"验证邀请令牌请求")
@router.put("/{workspace_id}/switch", response_model=ApiResponse)
@workspace_access_guard()
def switch_workspace(
workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""切换工作空间"""
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
workspace_service.switch_workspace(
db=db,
workspace_id=workspace_id,
user=current_user,
)
api_logger.info(f"成功切换工作空间为 {workspace_id}")
return success(msg="工作空间切换成功")
@router.get("/storage", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_workspace_storage_type(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前工作空间的存储类型"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的存储类型")
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
user=current_user
)
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
@router.get("/workspace_models", response_model=ApiResponse)
@cur_workspace_access_guard()
def workspace_models_configs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的模型配置")
configs = workspace_service.get_workspace_models_configs(
db=db,
workspace_id=workspace_id,
user=current_user
)
if configs is None:
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="工作空间不存在或无权访问"
)
api_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return success(data=configs, msg="模型配置获取成功")