feat: Add base project structure with API and web components
This commit is contained in:
108
api/app/schemas/__init__.py
Normal file
108
api/app/schemas/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from .item_schema import Item
|
||||
from .user_schema import User, UserCreate, UserUpdate
|
||||
from .workspace_schema import Workspace, WorkspaceCreate, WorkspaceMember, WorkspaceMemberCreate
|
||||
from .token_schema import Token, TokenData
|
||||
from .knowledge_schema import Knowledge, KnowledgeCreate, KnowledgeUpdate
|
||||
from .document_schema import Document, DocumentCreate, DocumentUpdate
|
||||
from .file_schema import File, FileCreate, FileUpdate
|
||||
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
||||
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
||||
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
||||
from .app_schema import (
|
||||
DraftRunRequest,
|
||||
DraftRunResponse,
|
||||
DraftRunStreamChunk,
|
||||
App,
|
||||
AppCreate,
|
||||
AppUpdate,
|
||||
AgentConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate,
|
||||
AppRelease,
|
||||
ModelParameters,
|
||||
KnowledgeRetrievalConfig,
|
||||
MemoryConfig,
|
||||
ToolConfig,
|
||||
VariableDefinition,
|
||||
)
|
||||
from .conversation_schema import (
|
||||
Conversation,
|
||||
ConversationCreate,
|
||||
ConversationWithMessages,
|
||||
Message,
|
||||
MessageCreate,
|
||||
ChatRequest,
|
||||
ChatResponse,
|
||||
)
|
||||
from .multi_agent_schema import (
|
||||
SubAgentConfig,
|
||||
RoutingRule,
|
||||
ExecutionConfig,
|
||||
MultiAgentConfigCreate,
|
||||
MultiAgentConfigUpdate,
|
||||
MultiAgentConfigSchema,
|
||||
MultiAgentRunRequest,
|
||||
MultiAgentRunResponse,
|
||||
SubAgentResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Item",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"Workspace",
|
||||
"WorkspaceCreate",
|
||||
"WorkspaceMember",
|
||||
"WorkspaceMemberCreate",
|
||||
"Token",
|
||||
"Knowledge",
|
||||
"KnowledgeCreate",
|
||||
"KnowledgeUpdate",
|
||||
"Document",
|
||||
"DocumentCreate",
|
||||
"DocumentUpdate",
|
||||
"File",
|
||||
"FileCreate",
|
||||
"FileUpdate",
|
||||
"Tenant",
|
||||
"TenantCreate",
|
||||
"TenantUpdate",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkRetrieve",
|
||||
"KnowledgeShare",
|
||||
"KnowledgeShareCreate",
|
||||
"DraftRunRequest",
|
||||
"DraftRunResponse",
|
||||
"DraftRunStreamChunk",
|
||||
"App",
|
||||
"AppCreate",
|
||||
"AppUpdate",
|
||||
"AgentConfig",
|
||||
"AgentConfigCreate",
|
||||
"AgentConfigUpdate",
|
||||
"AppRelease",
|
||||
"ModelParameters",
|
||||
"KnowledgeRetrievalConfig",
|
||||
"MemoryConfig",
|
||||
"ToolConfig",
|
||||
"VariableDefinition",
|
||||
"Conversation",
|
||||
"ConversationCreate",
|
||||
"ConversationWithMessages",
|
||||
"Message",
|
||||
"MessageCreate",
|
||||
"ChatRequest",
|
||||
"ChatResponse",
|
||||
# Multi-Agent Schemas
|
||||
"SubAgentConfig",
|
||||
"RoutingRule",
|
||||
"ExecutionConfig",
|
||||
"MultiAgentConfigCreate",
|
||||
"MultiAgentConfigUpdate",
|
||||
"MultiAgentConfigSchema",
|
||||
"MultiAgentRunRequest",
|
||||
"MultiAgentRunResponse",
|
||||
"SubAgentResult",
|
||||
]
|
||||
104
api/app/schemas/api_key_schema.py
Normal file
104
api/app/schemas/api_key_schema.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""API Key Schema"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, List
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
|
||||
|
||||
class ApiKeyCreate(BaseModel):
|
||||
"""创建 API Key"""
|
||||
name: str = Field(..., description="API Key 名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="描述")
|
||||
type: ApiKeyType = Field(..., description="API Key 类型")
|
||||
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||
resource_type: Optional[str] = Field(None, description="资源类型")
|
||||
rate_limit: Optional[int] = Field(100, description="速率限制(请求/分钟)", ge=1)
|
||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||
|
||||
|
||||
class ApiKeyUpdate(BaseModel):
|
||||
"""更新 API Key"""
|
||||
name: Optional[str] = Field(None, description="API Key 名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="描述")
|
||||
scopes: Optional[List[str]] = Field(None, description="权限范围列表")
|
||||
rate_limit: Optional[int] = Field(None, description="速率限制(请求/分钟)", ge=1)
|
||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
"""API Key 响应(创建时返回,包含明文 Key)"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str]
|
||||
api_key: str = Field(..., description="API Key 明文(仅创建时返回)")
|
||||
key_prefix: str
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
rate_limit: int
|
||||
quota_limit: Optional[int]
|
||||
expires_at: Optional[datetime.datetime]
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
"""API Key 信息(不包含明文 Key)"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str]
|
||||
key_prefix: str
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
rate_limit: int
|
||||
quota_limit: Optional[int]
|
||||
quota_used: int
|
||||
expires_at: Optional[datetime.datetime]
|
||||
is_active: bool
|
||||
last_used_at: Optional[datetime.datetime]
|
||||
usage_count: int
|
||||
workspace_id: uuid.UUID
|
||||
created_by: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
class ApiKeyStats(BaseModel):
|
||||
"""API Key 使用统计"""
|
||||
total_requests: int = Field(..., description="总请求数")
|
||||
requests_today: int = Field(..., description="今日请求数")
|
||||
quota_used: int = Field(..., description="已使用配额")
|
||||
quota_limit: Optional[int] = Field(None, description="配额限制")
|
||||
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
|
||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||
|
||||
|
||||
class ApiKeyQuery(BaseModel):
|
||||
"""API Key 查询参数"""
|
||||
type: Optional[ApiKeyType] = Field(None, description="API Key 类型")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||
page: int = Field(1, ge=1, description="页码")
|
||||
pagesize: int = Field(10, ge=1, le=100, description="每页数量")
|
||||
|
||||
|
||||
class ApiKeyAuth(BaseModel):
|
||||
"""API Key 认证信息"""
|
||||
api_key_id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
425
api/app/schemas/app_schema.py
Normal file
425
api/app/schemas/app_schema.py
Normal file
@@ -0,0 +1,425 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, Any, List, Dict, TYPE_CHECKING
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
|
||||
class KnowledgeBaseConfig(BaseModel):
|
||||
"""单个知识库配置"""
|
||||
kb_id: str = Field(..., description="知识库ID")
|
||||
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
||||
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
||||
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
||||
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
||||
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
||||
|
||||
|
||||
class KnowledgeRetrievalConfig(BaseModel):
|
||||
"""知识库检索配置(支持多个知识库,每个有独立配置)"""
|
||||
knowledge_bases: List[KnowledgeBaseConfig] = Field(
|
||||
default_factory=list,
|
||||
description="关联的知识库列表,每个知识库有独立配置"
|
||||
)
|
||||
|
||||
# 多知识库融合策略
|
||||
merge_strategy: str = Field(
|
||||
default="weighted",
|
||||
description="多知识库结果融合策略: weighted | rrf | concat"
|
||||
)
|
||||
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
|
||||
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
|
||||
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
"""记忆配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用对话历史记忆")
|
||||
memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型")
|
||||
max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数")
|
||||
|
||||
|
||||
class ModelParameters(BaseModel):
|
||||
"""模型参数配置"""
|
||||
temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="温度参数,控制输出的随机性")
|
||||
max_tokens: int = Field(default=2000, ge=1, le=32000, description="最大生成token数")
|
||||
top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="核采样参数")
|
||||
frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="频率惩罚")
|
||||
presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0, description="存在惩罚")
|
||||
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义"""
|
||||
name: str = Field(..., description="变量名称(标识符)")
|
||||
display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)")
|
||||
type: str = Field(
|
||||
default="string",
|
||||
description="变量类型: string(单行文本) | text(多行文本) | number(数字)"
|
||||
)
|
||||
required: bool = Field(default=False, description="是否必填")
|
||||
description: Optional[str] = Field(default=None, description="变量描述")
|
||||
max_length: Optional[int] = Field(default=None, description="最大长度(用于文本类型)")
|
||||
|
||||
|
||||
class AgentConfigCreate(BaseModel):
|
||||
"""Agent 行为配置"""
|
||||
# 提示词配置
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则")
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID")
|
||||
model_parameters: ModelParameters = Field(
|
||||
default_factory=ModelParameters,
|
||||
description="模型参数配置(temperature、max_tokens 等)"
|
||||
)
|
||||
|
||||
# 知识库关联
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||
default=None,
|
||||
description="知识库检索配置"
|
||||
)
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(
|
||||
default_factory=lambda: MemoryConfig(enabled=True),
|
||||
description="对话历史记忆配置"
|
||||
)
|
||||
|
||||
# 变量配置
|
||||
variables: List[VariableDefinition] = Field(
|
||||
default_factory=list,
|
||||
description="Agent 可用的变量列表"
|
||||
)
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)"
|
||||
)
|
||||
|
||||
|
||||
class AppCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
type: str = Field(pattern=r"^(agent|workflow|multi_agent)$")
|
||||
visibility: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
tags: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
# only for type=agent
|
||||
agent_config: Optional[AgentConfigCreate] = None
|
||||
|
||||
# only for type=multi_agent
|
||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AppUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
visibility: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentConfigUpdate(BaseModel):
|
||||
"""更新 Agent 行为配置"""
|
||||
# 提示词配置
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID")
|
||||
model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置")
|
||||
|
||||
# 知识库关联
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||
default=None,
|
||||
description="知识库检索配置"
|
||||
)
|
||||
|
||||
# 记忆配置
|
||||
memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置")
|
||||
|
||||
# 变量配置
|
||||
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
||||
|
||||
# 工具配置
|
||||
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
class App(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
created_by: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
type: str
|
||||
visibility: str
|
||||
status: str
|
||||
tags: List[str] = []
|
||||
current_release_id: Optional[uuid.UUID] = None
|
||||
is_active: bool
|
||||
is_shared: bool = False # 是否是共享应用(从其他工作空间共享来的)
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
"""Agent 配置输出 Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
|
||||
# 提示词
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = None
|
||||
model_parameters: ModelParameters = Field(default_factory=ModelParameters)
|
||||
|
||||
# 知识库检索
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True))
|
||||
|
||||
# 变量配置
|
||||
variables: List[VariableDefinition] = []
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = {}
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_validator("model_parameters", mode="before")
|
||||
@classmethod
|
||||
def validate_model_parameters(cls, v):
|
||||
"""处理 None 值,返回默认的 ModelParameters"""
|
||||
if v is None:
|
||||
return ModelParameters()
|
||||
return v
|
||||
|
||||
@field_validator("memory", mode="before")
|
||||
@classmethod
|
||||
def validate_memory(cls, v):
|
||||
"""处理 None 值,返回默认的 MemoryConfig"""
|
||||
if v is None:
|
||||
return MemoryConfig(enabled=True)
|
||||
return v
|
||||
|
||||
@field_validator("variables", mode="before")
|
||||
@classmethod
|
||||
def validate_variables(cls, v):
|
||||
"""处理 None 值,返回空列表"""
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, v):
|
||||
"""处理 None 值,返回空字典"""
|
||||
if v is None:
|
||||
return {}
|
||||
return v
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class PublishRequest(BaseModel):
|
||||
"""发布应用请求"""
|
||||
version_name: str
|
||||
release_notes: Optional[str] = Field(None, description="版本说明")
|
||||
|
||||
|
||||
class AppRelease(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
version: int
|
||||
release_notes: Optional[str] = None
|
||||
version_name: str
|
||||
description: Optional[str] = None
|
||||
icon: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
name: str
|
||||
type: str
|
||||
visibility: str
|
||||
config: Dict[str, Any] = {}
|
||||
default_model_config_id: Optional[uuid.UUID] = None
|
||||
published_by: uuid.UUID
|
||||
publisher_name: str
|
||||
published_at: datetime.datetime
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("published_at", when_used="json")
|
||||
def _serialize_published_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ---------- App Share Schemas ----------
|
||||
|
||||
class AppShareCreate(BaseModel):
|
||||
"""应用分享请求"""
|
||||
target_workspace_ids: List[uuid.UUID] = Field(..., description="目标工作空间ID列表")
|
||||
|
||||
|
||||
class AppShare(BaseModel):
|
||||
"""应用分享输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
source_app_id: uuid.UUID
|
||||
source_workspace_id: uuid.UUID
|
||||
target_workspace_id: uuid.UUID
|
||||
shared_by: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ---------- Draft Run Schemas ----------
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
"""试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
"""试运行响应(非流式)"""
|
||||
message: str = Field(..., description="AI 回复消息")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
||||
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
||||
|
||||
|
||||
class DraftRunStreamChunk(BaseModel):
|
||||
"""试运行流式响应块"""
|
||||
event: str = Field(..., description="事件类型: start | message | end | error")
|
||||
data: Dict[str, Any] = Field(..., description="事件数据")
|
||||
|
||||
|
||||
# ---------- Draft Run Compare Schemas ----------
|
||||
|
||||
class ModelCompareItem(BaseModel):
|
||||
"""单个对比模型配置"""
|
||||
model_config_id: uuid.UUID = Field(..., description="模型配置ID")
|
||||
model_parameters: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="覆盖模型参数,如 temperature, max_tokens 等"
|
||||
)
|
||||
label: Optional[str] = Field(
|
||||
None,
|
||||
description="自定义显示标签,用于区分同一模型的不同配置"
|
||||
)
|
||||
conversation_id: Optional[str] = Field(
|
||||
None,
|
||||
description="会话ID,用于为每个模型指定独立的会话历史"
|
||||
)
|
||||
|
||||
|
||||
class DraftRunCompareRequest(BaseModel):
|
||||
"""多模型对比试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
conversation_id: Optional[str] = Field(None, description="会话ID")
|
||||
user_id: Optional[str] = Field(None, description="用户ID")
|
||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||
|
||||
models: List[ModelCompareItem] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
description="要对比的模型列表(1-5个)"
|
||||
)
|
||||
|
||||
parallel: bool = Field(True, description="是否并行执行")
|
||||
stream: bool = Field(False, description="是否流式返回")
|
||||
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
||||
|
||||
|
||||
class ModelRunResult(BaseModel):
|
||||
"""单个模型运行结果"""
|
||||
model_config_id: uuid.UUID
|
||||
model_name: str
|
||||
label: Optional[str] = None
|
||||
|
||||
parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数")
|
||||
|
||||
message: Optional[str] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: float
|
||||
error: Optional[str] = None
|
||||
|
||||
tokens_per_second: Optional[float] = None
|
||||
cost_estimate: Optional[float] = None
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
|
||||
class DraftRunCompareResponse(BaseModel):
|
||||
"""多模型对比响应"""
|
||||
results: List[ModelRunResult]
|
||||
|
||||
total_elapsed_time: float
|
||||
successful_count: int
|
||||
failed_count: int
|
||||
|
||||
fastest_model: Optional[str] = None
|
||||
cheapest_model: Optional[str] = None
|
||||
26
api/app/schemas/chunk_schema.py
Normal file
26
api/app/schemas/chunk_schema.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel, Field
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RetrieveType(StrEnum):
|
||||
"""Retrieval type enumeration"""
|
||||
PARTICIPLE = "participle"
|
||||
SEMANTIC = "semantic"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
class ChunkCreate(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChunkUpdate(BaseModel):
|
||||
content: str | None = Field(None)
|
||||
|
||||
|
||||
class ChunkRetrieve(BaseModel):
|
||||
query: str
|
||||
kb_ids: list[uuid.UUID]
|
||||
similarity_threshold: float | None = Field(None)
|
||||
vector_similarity_weight: float | None = Field(None)
|
||||
top_k: int | None = Field(None)
|
||||
retrieve_type: RetrieveType | None = Field(None)
|
||||
86
api/app/schemas/conversation_schema.py
Normal file
86
api/app/schemas/conversation_schema.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""会话和消息相关的 Schema"""
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
|
||||
class ConversationCreate(BaseModel):
|
||||
"""创建会话请求"""
|
||||
title: Optional[str] = Field(None, max_length=255, description="会话标题")
|
||||
user_id: Optional[str] = Field(None, description="用户ID(外部系统)")
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
|
||||
"""创建消息请求"""
|
||||
content: str = Field(..., description="消息内容")
|
||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
"""聊天请求(基于 share_token)"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
conversation_id: Optional[uuid.UUID] = Field(None, description="会话ID(多轮对话)")
|
||||
user_id: Optional[str] = Field(None, description="用户ID(外部系统)")
|
||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
class Message(BaseModel):
|
||||
"""消息输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
conversation_id: uuid.UUID
|
||||
role: str
|
||||
content: str
|
||||
meta_data: Optional[Dict[str, Any]] = None
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""会话输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
user_id: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
is_draft: bool
|
||||
message_count: int
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class ConversationWithMessages(Conversation):
|
||||
"""会话详情(包含消息列表)"""
|
||||
messages: List[Message] = []
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
"""聊天响应(非流式)"""
|
||||
conversation_id: uuid.UUID
|
||||
message: str
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
63
api/app/schemas/document_schema.py
Normal file
63
api/app/schemas/document_schema.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class DocumentBase(BaseModel):
|
||||
kb_id: uuid.UUID
|
||||
created_by: uuid.UUID | None = None
|
||||
file_id: uuid.UUID
|
||||
file_name: str
|
||||
file_ext: str
|
||||
file_size: int
|
||||
file_meta: dict
|
||||
parser_id: str
|
||||
parser_config: dict
|
||||
|
||||
|
||||
class DocumentCreate(DocumentBase):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentUpdate(BaseModel):
|
||||
file_id: uuid.UUID | None = Field(None)
|
||||
file_name: str | None = Field(None)
|
||||
file_ext: str | None = Field(None)
|
||||
file_size: int | None = Field(None)
|
||||
file_meta: dict | None = Field(None)
|
||||
parser_id: str | None = Field(None)
|
||||
parser_config: dict | None = Field(None)
|
||||
chunk_num: int | None = Field(None)
|
||||
progress: float | None = Field(None)
|
||||
progress_msg: str | None = Field(None)
|
||||
process_begin_at: datetime.datetime | None = Field(None)
|
||||
process_duration: float | None = Field(None)
|
||||
run: int | None = Field(None)
|
||||
status: int | None = Field(None)
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
id: uuid.UUID
|
||||
chunk_num: int
|
||||
progress: float
|
||||
progress_msg: str
|
||||
process_begin_at: datetime.datetime
|
||||
process_duration: float
|
||||
run: int
|
||||
status: int
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("process_begin_at", when_used="json")
|
||||
def _serialize_process_begin_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
17
api/app/schemas/end_user_schema.py
Normal file
17
api/app/schemas/end_user_schema.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import ConfigDict
|
||||
|
||||
class EndUser(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID = Field(description="终端用户ID")
|
||||
app_id: uuid.UUID = Field(description="应用ID")
|
||||
# end_user_id: str = Field(description="终端用户ID")
|
||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||
other_address: Optional[str] = Field(description="其他地址", default="")
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)
|
||||
39
api/app/schemas/file_schema.py
Normal file
39
api/app/schemas/file_schema.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class FileBase(BaseModel):
|
||||
kb_id: uuid.UUID
|
||||
created_by: uuid.UUID | None = None
|
||||
parent_id: uuid.UUID | None = None
|
||||
file_name: str
|
||||
file_ext: str
|
||||
file_size: int
|
||||
|
||||
|
||||
class FileCreate(FileBase):
|
||||
pass
|
||||
|
||||
|
||||
class CustomTextFileCreate(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
class FileUpdate(BaseModel):
|
||||
parent_id: uuid.UUID | None = Field(None)
|
||||
file_name: str | None = Field(None)
|
||||
file_ext: str | None = Field(None)
|
||||
file_size: str | None = Field(None)
|
||||
|
||||
|
||||
class File(FileBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
69
api/app/schemas/generic_file_schema.py
Normal file
69
api/app/schemas/generic_file_schema.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Schemas for Generic File Upload System
|
||||
"""
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional, Dict, Any
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.core.upload_enums import UploadContext
|
||||
|
||||
|
||||
class GenericFileBase(BaseModel):
|
||||
"""Base schema for generic file"""
|
||||
file_name: str = Field(..., description="文件名")
|
||||
context: UploadContext = Field(..., description="上传上下文")
|
||||
is_public: bool = Field(False, description="是否公开")
|
||||
file_metadata: Optional[Dict[str, Any]] = Field(default={}, description="文件元数据")
|
||||
|
||||
|
||||
class GenericFileCreate(GenericFileBase):
|
||||
"""Schema for creating a generic file"""
|
||||
tenant_id: uuid.UUID
|
||||
created_by: uuid.UUID
|
||||
file_ext: str
|
||||
file_size: int
|
||||
mime_type: Optional[str] = None
|
||||
storage_path: str
|
||||
|
||||
|
||||
class GenericFileResponse(BaseModel):
|
||||
"""Schema for generic file response"""
|
||||
id: uuid.UUID = Field(..., description="文件ID")
|
||||
file_name: str = Field(..., description="文件名")
|
||||
file_ext: str = Field(..., description="文件扩展名")
|
||||
file_size: int = Field(..., description="文件大小(字节)")
|
||||
mime_type: Optional[str] = Field(None, description="MIME类型")
|
||||
context: str = Field(..., description="上传上下文")
|
||||
access_url: Optional[str] = Field(None, description="访问URL")
|
||||
is_public: bool = Field(..., description="是否公开")
|
||||
file_metadata: Dict[str, Any] = Field(default={}, description="文件元数据")
|
||||
status: str = Field(..., description="文件状态")
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
created_at: datetime.datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime.datetime = Field(..., description="更新时间")
|
||||
|
||||
|
||||
class FileMetadataUpdate(BaseModel):
|
||||
"""Schema for updating file metadata"""
|
||||
file_name: Optional[str] = Field(None, description="文件名")
|
||||
file_metadata: Optional[Dict[str, Any]] = Field(None, description="文件元数据")
|
||||
is_public: Optional[bool] = Field(None, description="是否公开")
|
||||
|
||||
|
||||
class UploadResultSchema(BaseModel):
|
||||
"""Schema for upload result"""
|
||||
success: bool = Field(..., description="是否成功")
|
||||
file_id: Optional[uuid.UUID] = Field(None, description="文件ID")
|
||||
file_name: str = Field(..., description="文件名")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
file_info: Optional[GenericFileResponse] = Field(None, description="文件信息")
|
||||
|
||||
|
||||
class BatchUploadResponse(BaseModel):
|
||||
"""Schema for batch upload response"""
|
||||
total: int = Field(..., description="总文件数")
|
||||
success_count: int = Field(..., description="成功数量")
|
||||
failed_count: int = Field(..., description="失败数量")
|
||||
results: list[UploadResultSchema] = Field(..., description="上传结果列表")
|
||||
5
api/app/schemas/item_schema.py
Normal file
5
api/app/schemas/item_schema.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
price: float
|
||||
69
api/app/schemas/knowledge_schema.py
Normal file
69
api/app/schemas/knowledge_schema.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
from .user_schema import User
|
||||
from .model_schema import ModelConfig
|
||||
from typing import Optional
|
||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||
|
||||
|
||||
class KnowledgeBase(BaseModel):
|
||||
workspace_id: uuid.UUID | None = None
|
||||
created_by: uuid.UUID | None = None
|
||||
parent_id: uuid.UUID | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
avatar: str | None = None
|
||||
type: KnowledgeType | None = None
|
||||
permission_id: PermissionType | None = None
|
||||
embedding_id: uuid.UUID | None = None
|
||||
reranker_id: uuid.UUID | None = None
|
||||
llm_id: uuid.UUID | None = None
|
||||
image2text_id: uuid.UUID | None = None
|
||||
doc_num: int | None = None
|
||||
chunk_num: int | None = None
|
||||
parser_id: str | None = None
|
||||
parser_config: dict | None = None
|
||||
|
||||
|
||||
class KnowledgeCreate(KnowledgeBase):
|
||||
pass
|
||||
|
||||
class KnowledgeUpdate(BaseModel):
|
||||
parent_id: uuid.UUID | None = Field(None)
|
||||
name: str | None = Field(None)
|
||||
description: str | None = Field(None)
|
||||
avatar: str | None = Field(None)
|
||||
type: KnowledgeType | None = Field(None)
|
||||
permission_id: PermissionType | None = Field(None)
|
||||
embedding_id: uuid.UUID | None = Field(None)
|
||||
reranker_id: uuid.UUID | None = Field(None)
|
||||
llm_id: uuid.UUID | None = Field(None)
|
||||
image2text_id: uuid.UUID | None = Field(None)
|
||||
doc_num: int | None = Field(None)
|
||||
chunk_num: int | None = Field(None)
|
||||
parser_id: str | None = Field(None)
|
||||
parser_config: dict | None = Field(None)
|
||||
status: int | None = Field(None)
|
||||
|
||||
|
||||
class Knowledge(KnowledgeBase):
|
||||
id: uuid.UUID
|
||||
status: int
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
created_user: User
|
||||
embedding: Optional[ModelConfig] = None
|
||||
reranker: Optional[ModelConfig] = None
|
||||
llm: Optional[ModelConfig] = None
|
||||
image2text: Optional[ModelConfig] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
37
api/app/schemas/knowledgeshare_schema.py
Normal file
37
api/app/schemas/knowledgeshare_schema.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
from .knowledge_schema import Knowledge
|
||||
from .workspace_schema import Workspace
|
||||
from .user_schema import User
|
||||
|
||||
|
||||
class KnowledgeShareBase(BaseModel):
|
||||
source_kb_id: uuid.UUID
|
||||
source_workspace_id: uuid.UUID | None = None
|
||||
target_kb_id: uuid.UUID | None = None
|
||||
target_workspace_id: uuid.UUID
|
||||
shared_by: uuid.UUID | None = None
|
||||
|
||||
|
||||
class KnowledgeShareCreate(KnowledgeShareBase):
|
||||
pass
|
||||
|
||||
|
||||
class KnowledgeShare(KnowledgeShareBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
target_kb: Knowledge
|
||||
target_workspace: Workspace
|
||||
shared_user: User
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
17
api/app/schemas/memory_agent_schema.py
Normal file
17
api/app/schemas/memory_agent_schema.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserInput(BaseModel):
|
||||
message: str
|
||||
history: list[dict]
|
||||
search_switch: str
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
message: str
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
18
api/app/schemas/memory_increment_schema.py
Normal file
18
api/app/schemas/memory_increment_schema.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from pydantic import ConfigDict
|
||||
|
||||
class MemoryIncrement(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
workspace_id: uuid.UUID = Field(description="工作空间ID")
|
||||
total_num: int = Field(description="增量总数")
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now())
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now())
|
||||
|
||||
@field_serializer('created_at', 'updated_at')
|
||||
def serialize_datetime(self, dt: datetime.datetime, _info) -> str:
|
||||
"""将日期时间序列化为年月日格式"""
|
||||
return dt.strftime('%Y-%m-%d')
|
||||
343
api/app/schemas/memory_storage_schema.py
Normal file
343
api/app/schemas/memory_storage_schema.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
所有的内容是放错误地方了,应该放在models
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
import time
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 原 UserInput 相关 Schema (保留原有功能)
|
||||
# ============================================================================
|
||||
class UserInput(BaseModel):
|
||||
message: str
|
||||
history: list[dict]
|
||||
search_switch: str
|
||||
group_id: str
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
message: str
|
||||
group_id: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 从 json_schema.py 迁移的 Schema
|
||||
# ============================================================================
|
||||
class BaseDataSchema(BaseModel):
|
||||
"""Base schema for the data"""
|
||||
id: str = Field(..., description="The unique identifier for the data entry.")
|
||||
statement: str = Field(..., description="The statement text.")
|
||||
group_id: str = Field(..., description="The group identifier.")
|
||||
chunk_id: str = Field(..., description="The chunk identifier.")
|
||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
|
||||
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
|
||||
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
|
||||
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
if isinstance(d, dict):
|
||||
v["data"] = [d]
|
||||
return v
|
||||
|
||||
|
||||
class ConflictSchema(BaseModel):
|
||||
"""Schema for the conflict data in the reflexion_data"""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
if isinstance(d, dict):
|
||||
v["data"] = [d]
|
||||
return v
|
||||
|
||||
|
||||
class ReflexionSchema(BaseModel):
|
||||
"""Schema for the reflexion data in the reflexion_data"""
|
||||
reason: str = Field(..., description="The reason for the reflexion.")
|
||||
solution: str = Field(..., description="The solution for the reflexion.")
|
||||
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||
resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.")
|
||||
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the reflexion result data in the reflexion_data.json file."""
|
||||
# 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data.")
|
||||
reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
if isinstance(conflict, dict) and conflict.get("conflict") is False:
|
||||
v["resolved"] = None
|
||||
else:
|
||||
resolved = v.get("resolved")
|
||||
if isinstance(resolved, dict):
|
||||
orig = resolved.get("original_memory_id")
|
||||
mem = resolved.get("resolved_memory")
|
||||
if orig is None and (mem is None or mem == {}):
|
||||
v["resolved"] = None
|
||||
return v
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 从 messages.py 迁移的 Schema
|
||||
# ============================================================================
|
||||
|
||||
# Composite key identifying a config row
|
||||
class ConfigKey(BaseModel): # 配置参数键模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_id: int = Field("config_id", description="配置唯一标识(字符串)")
|
||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||
|
||||
|
||||
# Allowed chunking strategies (extendable later)
|
||||
ChunkerStrategy = Literal[ # 分块策略枚举
|
||||
"RecursiveChunker",
|
||||
"TokenChunker",
|
||||
"SemanticChunker",
|
||||
"NeuralChunker",
|
||||
"HybridChunker",
|
||||
"LLMChunker",
|
||||
"SentenceChunker",
|
||||
"LateChunker"
|
||||
]
|
||||
|
||||
|
||||
# 这是 Request body示例
|
||||
class ConfigParams(ConfigKey): # 创建配置参数模型 旧
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
# Boolean switches
|
||||
enable_llm_dedup_blockwise: bool = Field(True, description="启用LLM决策去重")
|
||||
enable_llm_disambiguation: bool = Field(True, description="启用LLM决策消歧")
|
||||
deep_retrieval: bool = Field(True, description="深度检索开关(保留既有拼写)")
|
||||
|
||||
# Thresholds in [0, 1]
|
||||
t_type_strict: float = Field(0.8, ge=0.0, le=1.0, description="类型严格阈值")
|
||||
t_name_strict: float = Field(0.8, ge=0.0, le=1.0, description="名称严格阈值")
|
||||
t_overall: float = Field(0.8, ge=0.0, le=1.0, description="综合阈值")
|
||||
state: bool = Field(False, description="配置使用状态(True/False)")
|
||||
# Chunker strategy selection (must be one of the declared literals)
|
||||
chunker_strategy: ChunkerStrategy = Field(
|
||||
"RecursiveChunker",
|
||||
description=(
|
||||
"分块策略:RecursiveChunker/TokenChunker/SemanticChunker/NeuralChunker/"
|
||||
"HybridChunker/LLMChunker/SentenceChunker/LateChunker"
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("chunker_strategy", mode="before")
|
||||
@classmethod
|
||||
def map_chunker_aliases(cls, v: str):
|
||||
# 允许常见别名并映射到合法枚举
|
||||
if isinstance(v, str):
|
||||
m = v.strip().lower()
|
||||
alias_map = {
|
||||
"auto": "RecursiveChunker",
|
||||
"by_sentence": "SentenceChunker",
|
||||
"by_paragraph": "SemanticChunker",
|
||||
"fixed_tokens": "TokenChunker",
|
||||
"递归分块": "RecursiveChunker",
|
||||
"token 分块": "TokenChunker",
|
||||
"token分块": "TokenChunker",
|
||||
"语义分块": "SemanticChunker",
|
||||
"神经网络分块": "NeuralChunker",
|
||||
"混合分块": "HybridChunker",
|
||||
"llm 分块": "LLMChunker",
|
||||
"llm分块": "LLMChunker",
|
||||
"句子分块": "SentenceChunker",
|
||||
"延迟分块": "LateChunker",
|
||||
}
|
||||
if m in alias_map:
|
||||
return alias_map[m]
|
||||
return v
|
||||
|
||||
@field_validator("config_id", "user_id", "apply_id")
|
||||
@classmethod
|
||||
def non_empty_str(cls, v: str) -> str:
|
||||
s = str(v).strip() if v is not None else ""
|
||||
if not s:
|
||||
raise ValueError("标识字段必须为非空字符串")
|
||||
return s
|
||||
|
||||
|
||||
class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,去除主键)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
||||
|
||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_id: int = Field("配置ID", description="配置ID(字符串)")
|
||||
|
||||
|
||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id: Optional[int] = None
|
||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
|
||||
|
||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id: Optional[int] = None
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
enable_llm_dedup_blockwise: Optional[bool] = None
|
||||
enable_llm_disambiguation: Optional[bool] = None
|
||||
deep_retrieval: Optional[bool] = Field(None, validation_alias="deep_retrieval")
|
||||
|
||||
t_type_strict: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
t_name_strict: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
t_overall: Optional[float] = Field(None, ge=0.0, le=1.0)
|
||||
state: Optional[bool] = None
|
||||
chunker_strategy: Optional[ChunkerStrategy] = None
|
||||
# 句子提取
|
||||
statement_granularity: Optional[int] = Field(2, ge=1, le=3, description="陈述提取颗粒度,挡位 1/2/3;默认 2")
|
||||
include_dialogue_context: Optional[bool] = None
|
||||
max_context: Optional[int] = Field(1000, gt=100, description="对话语境中包含字符的最大数量(>100);默认 1000")
|
||||
|
||||
# 剪枝配置:与 runtime.json 中 pruning 段对应
|
||||
pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝")
|
||||
pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field(
|
||||
None, description="智能剪枝场景:education/online_service/outbound"
|
||||
)
|
||||
pruning_threshold: Optional[float] = Field(
|
||||
None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)"
|
||||
)
|
||||
|
||||
# 反思配置
|
||||
enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思")
|
||||
iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(
|
||||
"3", description="反思迭代周期,单位小时"
|
||||
)
|
||||
reflexion_range: Optional[Literal["retrieval", "database"]] = Field(
|
||||
"retrieval", description="反思范围:部分/全部"
|
||||
)
|
||||
baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(
|
||||
"TIME", description="基线:时间/事实/时间和事实"
|
||||
)
|
||||
|
||||
@field_validator("chunker_strategy", mode="before")
|
||||
@classmethod
|
||||
def map_chunker_aliases_update(cls, v: str):
|
||||
if isinstance(v, str):
|
||||
m = v.strip().lower()
|
||||
alias_map = {
|
||||
"auto": "RecursiveChunker",
|
||||
"by_sentence": "SentenceChunker",
|
||||
"by_paragraph": "SemanticChunker",
|
||||
"fixed_tokens": "TokenChunker",
|
||||
"递归分块": "RecursiveChunker",
|
||||
"token 分块": "TokenChunker",
|
||||
"token分块": "TokenChunker",
|
||||
"语义分块": "SemanticChunker",
|
||||
"神经网络分块": "NeuralChunker",
|
||||
"混合分块": "HybridChunker",
|
||||
"llm 分块": "LLMChunker",
|
||||
"llm分块": "LLMChunker",
|
||||
"句子分块": "SentenceChunker",
|
||||
"延迟分块": "LateChunker",
|
||||
}
|
||||
if m in alias_map:
|
||||
return alias_map[m]
|
||||
return v
|
||||
|
||||
|
||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||
# 遗忘引擎配置参数更新模型
|
||||
config_id: Optional[int] = None
|
||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||
|
||||
|
||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
config_id: int = Field(..., description="配置ID(唯一)")
|
||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: Optional[int] = None
|
||||
user_id: Optional[str] = None
|
||||
apply_id: Optional[str] = None
|
||||
|
||||
limit: int = Field(20, ge=1, le=200, description="返回数量上限")
|
||||
offset: int = Field(0, ge=0, description="起始偏移")
|
||||
|
||||
|
||||
class ApiResponse(BaseModel): # 通用API响应模型
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
code: int = Field(..., description="0=成功,非0=各类业务异常")
|
||||
msg: str = Field("", description="说明信息")
|
||||
data: Optional[Any] = Field(None, description="返回数据载荷")
|
||||
error: str = Field("", description="错误信息,失败时有值,成功为空字符串")
|
||||
time: Optional[int] = Field(None, description="响应时间(毫秒,Unix 时间戳)")
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(round(time.time() * 1000))
|
||||
|
||||
|
||||
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:
|
||||
return ApiResponse(code=0, msg=msg, data=data, error="", time=time or _now_ms())
|
||||
|
||||
|
||||
def fail(
|
||||
msg: str,
|
||||
error_code: str = "ERROR",
|
||||
data: Optional[Any] = None,
|
||||
time: Optional[int] = None,
|
||||
query_preview: Optional[str] = None,
|
||||
) -> ApiResponse:
|
||||
payload = data
|
||||
if query_preview is not None:
|
||||
if payload is None:
|
||||
payload = {"query_preview": query_preview}
|
||||
elif isinstance(payload, dict):
|
||||
payload = {**payload, "query_preview": query_preview}
|
||||
else:
|
||||
payload = {"data": payload, "query_preview": query_preview}
|
||||
|
||||
return ApiResponse(
|
||||
code=1,
|
||||
msg=msg,
|
||||
data=payload,
|
||||
error=error_code,
|
||||
time=time or _now_ms(),
|
||||
)
|
||||
162
api/app/schemas/model_schema.py
Normal file
162
api/app/schemas/model_schema.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
from typing import Optional, List, Dict, Any
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
|
||||
|
||||
|
||||
# ModelConfig Schemas
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""模型配置基础Schema"""
|
||||
name: str = Field(..., description="模型显示名称", max_length=255)
|
||||
type: ModelType = Field(..., description="模型类型")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
is_public: bool = Field(False, description="是否公开")
|
||||
|
||||
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
"""用于在创建模型时内嵌创建API Key的Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
|
||||
|
||||
class ModelConfigCreate(ModelConfigBase):
|
||||
"""创建模型配置Schema"""
|
||||
api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置")
|
||||
skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证")
|
||||
|
||||
|
||||
class ModelConfigUpdate(BaseModel):
|
||||
"""更新模型配置Schema"""
|
||||
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
is_public: Optional[bool] = Field(None, description="是否公开")
|
||||
|
||||
|
||||
class ModelConfig(ModelConfigBase):
|
||||
"""模型配置Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
api_keys: List["ModelApiKey"] = []
|
||||
|
||||
|
||||
# ModelApiKey Schemas
|
||||
class ModelApiKeyBase(BaseModel):
|
||||
"""API Key基础Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
|
||||
|
||||
class ModelApiKeyCreate(ModelApiKeyBase):
|
||||
"""创建API Key Schema"""
|
||||
model_config_id: uuid.UUID = Field(..., description="模型配置ID")
|
||||
|
||||
|
||||
class ModelApiKeyUpdate(BaseModel):
|
||||
"""更新API Key Schema"""
|
||||
model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
|
||||
provider: Optional[ModelProvider] = Field(None, description="API Key提供商")
|
||||
api_key: Optional[str] = Field(None, description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
priority: Optional[str] = Field(None, description="优先级", max_length=10)
|
||||
|
||||
|
||||
class ModelApiKey(ModelApiKeyBase):
|
||||
"""API Key Schema"""
|
||||
id: uuid.UUID
|
||||
model_config_id: uuid.UUID
|
||||
usage_count: str
|
||||
last_used_at: Optional[datetime.datetime]
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("last_used_at", when_used="json")
|
||||
def _serialize_last_used_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# 查询和响应Schemas
|
||||
class ModelConfigQuery(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
page: int = Field(1, description="页码", ge=1)
|
||||
pagesize: int = Field(10, description="每页数量", ge=1, le=100)
|
||||
|
||||
class ModelMarketplace(BaseModel):
|
||||
"""模型广场响应Schema"""
|
||||
llm_models: List[ModelConfig] = []
|
||||
embedding_models: List[ModelConfig] = []
|
||||
rerank_models: List[ModelConfig] = []
|
||||
total_count: int
|
||||
active_count: int
|
||||
|
||||
|
||||
# 统计信息Schema
|
||||
class ModelStats(BaseModel):
|
||||
"""模型统计信息Schema"""
|
||||
total_models: int
|
||||
active_models: int
|
||||
llm_count: int
|
||||
embedding_count: int
|
||||
rerank_count: int
|
||||
provider_stats: Dict[str, int]
|
||||
|
||||
|
||||
# 验证模型配置Schema
|
||||
class ModelValidateRequest(BaseModel):
|
||||
"""验证模型配置请求"""
|
||||
model_name: str = Field(..., description="模型实际名称")
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥")
|
||||
api_base: Optional[str] = Field(None, description="API基础URL")
|
||||
model_type: Optional[ModelType] = Field(ModelType.LLM, description="模型类型")
|
||||
test_message: Optional[str] = Field("Hello", description="测试消息")
|
||||
|
||||
|
||||
class ModelValidateResponse(BaseModel):
|
||||
"""验证模型配置响应"""
|
||||
valid: bool = Field(..., description="是否有效")
|
||||
message: str = Field(..., description="验证消息")
|
||||
response: Optional[str] = Field(None, description="模型响应内容")
|
||||
elapsed_time: Optional[float] = Field(None, description="响应时间(秒)")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
usage: Optional[Dict[str, Any]] = Field(None, description="Token使用情况")
|
||||
|
||||
|
||||
# 更新前向引用
|
||||
ModelConfig.model_rebuild()
|
||||
167
api/app/schemas/multi_agent_schema.py
Normal file
167
api/app/schemas/multi_agent_schema.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""多 Agent 相关的 Schema 定义"""
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
# ==================== 子 Agent 配置 ====================
|
||||
|
||||
class SubAgentConfig(BaseModel):
|
||||
"""子 Agent 配置"""
|
||||
agent_id: uuid.UUID = Field(..., description="Agent ID")
|
||||
name: str = Field(..., description="Agent 名称")
|
||||
role: Optional[str] = Field(None, description="角色描述")
|
||||
priority: int = Field(default=1, ge=1, le=100, description="优先级(1-100)")
|
||||
capabilities: List[str] = Field(default_factory=list, description="能力列表")
|
||||
|
||||
|
||||
class RoutingRule(BaseModel):
|
||||
"""路由规则"""
|
||||
condition: str = Field(..., description="条件表达式")
|
||||
target_agent_id: uuid.UUID = Field(..., description="目标 Agent ID")
|
||||
priority: int = Field(default=1, ge=1, le=100, description="优先级")
|
||||
|
||||
|
||||
class ExecutionConfig(BaseModel):
|
||||
"""执行配置"""
|
||||
max_iterations: int = Field(default=5, ge=1, le=20, description="最大迭代次数")
|
||||
timeout: int = Field(default=60, ge=10, le=300, description="超时时间(秒)")
|
||||
parallel_limit: int = Field(default=3, ge=1, le=10, description="并行限制")
|
||||
retry_on_failure: bool = Field(default=True, description="失败时是否重试")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数")
|
||||
|
||||
|
||||
# ==================== 多 Agent 配置 ====================
|
||||
|
||||
class MultiAgentConfigCreate(BaseModel):
|
||||
"""创建多 Agent 配置"""
|
||||
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
orchestration_mode: str = Field(
|
||||
...,
|
||||
pattern="^(sequential|parallel|conditional|loop)$",
|
||||
description="编排模式:sequential|parallel|conditional|loop"
|
||||
)
|
||||
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
|
||||
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
|
||||
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||
aggregation_strategy: str = Field(
|
||||
default="merge",
|
||||
pattern="^(merge|vote|priority|custom)$",
|
||||
description="结果整合策略:merge|vote|priority|custom"
|
||||
)
|
||||
|
||||
|
||||
class MultiAgentConfigUpdate(BaseModel):
|
||||
"""更新多 Agent 配置"""
|
||||
master_agent_id: Optional[uuid.UUID] = None
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
orchestration_mode: Optional[str] = Field(
|
||||
None,
|
||||
pattern="^(sequential|parallel|conditional|loop)$"
|
||||
)
|
||||
sub_agents: Optional[List[SubAgentConfig]] = None
|
||||
routing_rules: Optional[List[RoutingRule]] = None
|
||||
execution_config: Optional[ExecutionConfig] = None
|
||||
aggregation_strategy: Optional[str] = Field(
|
||||
None,
|
||||
pattern="^(merge|vote|priority|custom)$"
|
||||
)
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class MultiAgentConfigSchema(BaseModel):
|
||||
"""多 Agent 配置输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
master_agent_id: uuid.UUID
|
||||
master_agent_name: Optional[str]
|
||||
orchestration_mode: str
|
||||
sub_agents: List[Dict[str, Any]]
|
||||
routing_rules: Optional[List[Dict[str, Any]]]
|
||||
execution_config: Dict[str, Any]
|
||||
aggregation_strategy: str
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ==================== 多 Agent 运行 ====================
|
||||
|
||||
class MultiAgentRunRequest(BaseModel):
|
||||
"""多 Agent 运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID")
|
||||
user_id: Optional[str] = Field(None, description="用户 ID")
|
||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||
use_llm_routing: bool = Field(default=True, description="是否启用 LLM 路由(默认启用)")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
web_search: bool = Field(default=False, description="是否启用网络搜索")
|
||||
memory: bool = Field(default=True, description="是否启用记忆功能")
|
||||
|
||||
|
||||
class SubAgentResult(BaseModel):
|
||||
"""子 Agent 执行结果"""
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
|
||||
|
||||
class MultiAgentRunResponse(BaseModel):
|
||||
"""多 Agent 运行响应"""
|
||||
message: str = Field(..., description="最终结果")
|
||||
conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID")
|
||||
elapsed_time: float = Field(..., description="总耗时(秒)")
|
||||
mode: str = Field(..., description="执行模式")
|
||||
sub_results: Union[List[Dict[str, Any]], Dict[str, Any]] = Field(..., description="子 Agent 结果")
|
||||
usage: Optional[Dict[str, Any]] = Field(None, description="资源使用情况")
|
||||
|
||||
|
||||
# ==================== 智能路由测试 ====================
|
||||
|
||||
class RoutingTestRequest(BaseModel):
|
||||
"""路由测试请求"""
|
||||
message: str = Field(..., description="测试消息")
|
||||
conversation_id: Optional[uuid.UUID] = Field(None, description="会话 ID(可选)")
|
||||
routing_model_id: Optional[uuid.UUID] = Field(None, description="路由模型 ID(用于 LLM 路由)")
|
||||
use_llm: bool = Field(default=False, description="是否启用 LLM 路由")
|
||||
keyword_threshold: Optional[float] = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="关键词置信度阈值(0-1)"
|
||||
)
|
||||
force_new: bool = Field(default=False, description="是否强制重新路由")
|
||||
|
||||
|
||||
class RoutingTestCase(BaseModel):
|
||||
"""路由测试用例"""
|
||||
message: str = Field(..., description="测试消息")
|
||||
expected_agent_id: Optional[uuid.UUID] = Field(None, description="期望的 Agent ID")
|
||||
description: Optional[str] = Field(None, description="测试用例描述")
|
||||
|
||||
|
||||
class BatchRoutingTestRequest(BaseModel):
|
||||
"""批量路由测试请求"""
|
||||
test_cases: List[RoutingTestCase] = Field(..., description="测试用例列表")
|
||||
routing_model_id: Optional[uuid.UUID] = Field(None, description="路由模型 ID")
|
||||
use_llm: bool = Field(default=False, description="是否启用 LLM 路由")
|
||||
keyword_threshold: Optional[float] = Field(
|
||||
default=0.8,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="关键词置信度阈值"
|
||||
)
|
||||
61
api/app/schemas/prompt_schema.py
Normal file
61
api/app/schemas/prompt_schema.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from jinja2 import Environment, Template, meta
|
||||
from typing import Any, Dict
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
from abc import ABC
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class PromptMessageRole(str, Enum):
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
class TextPromptMessageContent(BaseModel):
|
||||
type: str = Field(default="text")
|
||||
data: str
|
||||
PromptMessageContentUnionTypes = TextPromptMessageContent
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
role: PromptMessageRole
|
||||
content: Union[str, List[PromptMessageContentUnionTypes], None] = None
|
||||
name: Union[str, None] = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return not self.content
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
if isinstance(self.content, str):
|
||||
return self.content
|
||||
elif isinstance(self.content, list):
|
||||
return "".join([item.data for item in self.content if isinstance(item, TextPromptMessageContent)])
|
||||
return ""
|
||||
|
||||
|
||||
def render_prompt_message(template_str: str, role: PromptMessageRole, params: Dict[str, Any]) -> PromptMessage:
|
||||
"""
|
||||
通用函数:自动解析模板变量,渲染PromptMessage
|
||||
- template_str: Jinja2模板字符串
|
||||
- role: PromptMessageRole
|
||||
- params: 提供模板变量的字典
|
||||
"""
|
||||
env = Environment()
|
||||
parsed_content = env.parse(template_str)
|
||||
variables = meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# 检查缺失参数,如果缺失则给默认值 ''
|
||||
for var in variables:
|
||||
if var not in params:
|
||||
params[var] = ""
|
||||
|
||||
# 渲染模板
|
||||
jinja_template = Template(template_str)
|
||||
rendered_text = jinja_template.render(**params)
|
||||
|
||||
return PromptMessage(
|
||||
role=role,
|
||||
content=[TextPromptMessageContent(data=rendered_text)]
|
||||
)
|
||||
|
||||
|
||||
104
api/app/schemas/release_share_schema.py
Normal file
104
api/app/schemas/release_share_schema.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
# ---------- Input Schemas ----------
|
||||
|
||||
class ReleaseShareCreate(BaseModel):
|
||||
"""创建/启用分享配置"""
|
||||
is_enabled: bool = Field(default=True, description="是否启用公开分享")
|
||||
require_password: bool = Field(default=False, description="是否需要密码访问")
|
||||
password: Optional[str] = Field(None, min_length=4, max_length=50, description="访问密码(4-50字符)")
|
||||
allow_embed: bool = Field(default=False, description="是否允许嵌入")
|
||||
embed_domains: Optional[List[str]] = Field(default=None, description="允许嵌入的域名白名单,空表示不限制")
|
||||
|
||||
|
||||
class ReleaseShareUpdate(BaseModel):
|
||||
"""更新分享配置"""
|
||||
is_enabled: Optional[bool] = Field(None, description="是否启用公开分享")
|
||||
require_password: Optional[bool] = Field(None, description="是否需要密码访问")
|
||||
password: Optional[str] = Field(None, min_length=4, max_length=50, description="访问密码")
|
||||
allow_embed: Optional[bool] = Field(None, description="是否允许嵌入")
|
||||
embed_domains: Optional[List[str]] = Field(None, description="允许嵌入的域名白名单")
|
||||
|
||||
|
||||
class PasswordVerifyRequest(BaseModel):
|
||||
"""密码验证请求"""
|
||||
password: str = Field(..., description="访问密码")
|
||||
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
"""获取访问 token 请求"""
|
||||
user_id: Optional[str] = Field(None, description="用户 ID(可选,不提供则自动生成)")
|
||||
password: Optional[str] = Field(None, description="访问密码(如果需要)")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
|
||||
class ReleaseShare(BaseModel):
|
||||
"""分享配置输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
release_id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
is_enabled: bool
|
||||
share_token: str
|
||||
share_url: str # 完整的公开访问 URL
|
||||
require_password: bool
|
||||
allow_embed: bool
|
||||
embed_domains: List[str] = []
|
||||
view_count: int
|
||||
last_accessed_at: Optional[datetime.datetime] = None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("last_accessed_at", when_used="json")
|
||||
def _serialize_last_accessed_at(self, dt: Optional[datetime.datetime]):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class SharedReleaseInfo(BaseModel):
|
||||
"""公开访问返回的应用信息"""
|
||||
app_name: str
|
||||
app_description: Optional[str] = None
|
||||
app_icon: Optional[str] = None
|
||||
app_type: str
|
||||
version: int
|
||||
release_notes: Optional[str] = None
|
||||
published_at: int
|
||||
|
||||
# 根据应用类型返回不同配置
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
# 访问控制信息
|
||||
require_password: bool
|
||||
is_password_verified: bool = False # 当前是否已验证密码
|
||||
|
||||
# 嵌入配置
|
||||
allow_embed: bool
|
||||
|
||||
|
||||
class EmbedCode(BaseModel):
|
||||
"""嵌入代码"""
|
||||
iframe_code: str = Field(..., description="iframe 嵌入代码")
|
||||
preview_url: str = Field(..., description="预览 URL")
|
||||
width: str = Field(default="100%", description="宽度")
|
||||
height: str = Field(default="600px", description="高度")
|
||||
|
||||
|
||||
class ShareStats(BaseModel):
|
||||
"""分享统计"""
|
||||
view_count: int
|
||||
last_accessed_at: Optional[int] = None
|
||||
created_at: int
|
||||
22
api/app/schemas/response_schema.py
Normal file
22
api/app/schemas/response_schema.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Optional
|
||||
import time
|
||||
|
||||
|
||||
class PageMeta(BaseModel):
|
||||
page: int = Field(..., description="当前页码,从1开始")
|
||||
pagesize: int = Field(..., description="每页数量")
|
||||
total: int = Field(..., description="总条数")
|
||||
hasnext: bool = Field(..., description="是否有下一页")
|
||||
|
||||
class PageData(BaseModel):
|
||||
page: PageMeta = Field(..., description="分页元数据")
|
||||
items: list = Field(..., description="分页数据列表")
|
||||
|
||||
|
||||
class ApiResponse(BaseModel):
|
||||
code: int = Field(0, description="业务状态码,0=成功,非0=各类业务异常")
|
||||
msg: str = Field("OK", description="给人看的简短提示")
|
||||
data: Optional[Any] = Field(None, description="具体数据")
|
||||
error: str = Field("", description="失败时的字段级错误信息,成功时为空字符串")
|
||||
time: int = Field(default_factory=lambda: int(time.time()), description="Unix时间戳(秒)")
|
||||
13
api/app/schemas/retrieval_info_schema.py
Normal file
13
api/app/schemas/retrieval_info_schema.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, Text
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import ConfigDict
|
||||
|
||||
class Host(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID = Field(description="宿主ID")
|
||||
host_id: uuid.UUID = Field(description="其他ID")
|
||||
retrieve_info: Optional[Text] = Field(description="检索信息")
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
65
api/app/schemas/tenant_schema.py
Normal file
65
api/app/schemas/tenant_schema.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
||||
from typing import Optional, List
|
||||
import datetime
|
||||
import uuid
|
||||
from app.core.exceptions import ValidationException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
|
||||
class TenantBase(BaseModel):
|
||||
"""租户基础Schema"""
|
||||
name: str = Field(..., description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if not v or not v.strip():
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip()
|
||||
|
||||
|
||||
class TenantCreate(TenantBase):
|
||||
"""创建租户Schema"""
|
||||
pass
|
||||
|
||||
|
||||
class TenantUpdate(BaseModel):
|
||||
"""更新租户Schema"""
|
||||
name: Optional[str] = Field(None, description="租户名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="租户描述", max_length=1000)
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v):
|
||||
if v is not None and (not v or not v.strip()):
|
||||
raise ValidationException('租户名称不能为空', code=BizCode.VALIDATION_FAILED)
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class Tenant(TenantBase):
|
||||
"""租户Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
class TenantQuery(BaseModel):
|
||||
"""租户查询Schema"""
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
page: int = Field(1, description="页码", ge=1)
|
||||
size: int = Field(10, description="每页数量", ge=1, le=100)
|
||||
|
||||
|
||||
class TenantList(BaseModel):
|
||||
"""租户列表响应Schema"""
|
||||
items: List[Tenant]
|
||||
total: int
|
||||
page: int
|
||||
size: int
|
||||
pages: int
|
||||
30
api/app/schemas/token_schema.py
Normal file
30
api/app/schemas/token_schema.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from pydantic import BaseModel, EmailStr, field_serializer
|
||||
from typing import Optional
|
||||
import datetime
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
expires_at: datetime.datetime
|
||||
refresh_expires_at: datetime.datetime
|
||||
|
||||
@field_serializer("expires_at", when_used="json")
|
||||
def _serialize_expires_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("refresh_expires_at", when_used="json")
|
||||
def _serialize_refresh_expires_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
class TokenData(BaseModel):
|
||||
userId: Optional[str] = None
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
class TokenRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite: Optional[str] = None
|
||||
|
||||
76
api/app/schemas/user_schema.py
Normal file
76
api/app/schemas/user_schema.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from dataclasses import field
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.models import Workspace
|
||||
from app.models.workspace_model import WorkspaceRole
|
||||
|
||||
|
||||
class UserBase(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_superuser: Optional[bool] = None
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
"""修改密码请求"""
|
||||
old_password: str = Field(..., description="当前密码")
|
||||
new_password: str = Field(..., min_length=6, description="新密码,至少6位")
|
||||
|
||||
|
||||
class AdminChangePasswordRequest(BaseModel):
|
||||
"""管理员修改用户密码请求"""
|
||||
user_id: uuid.UUID = Field(..., description="要修改密码的用户ID")
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,至少6位。如果不提供则自动生成随机密码")
|
||||
|
||||
|
||||
class ChangePasswordResponse(BaseModel):
|
||||
"""修改密码响应"""
|
||||
message: str
|
||||
success: bool = True
|
||||
generated_password: Optional[str] = Field(None, description="自动生成的密码(仅在管理员重置时返回)")
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
id: uuid.UUID
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
created_at: int
|
||||
last_login_at: Optional[int] = None
|
||||
current_workspace_id: Optional[uuid.UUID] = None
|
||||
current_workspace_name: Optional[str] = None
|
||||
role: Optional[WorkspaceRole] = None
|
||||
|
||||
# 将 datetime 转换为毫秒时间戳
|
||||
@validator("created_at", pre=True)
|
||||
def _created_at_to_ms(cls, v):
|
||||
if isinstance(v, datetime.datetime):
|
||||
return int(v.timestamp() * 1000)
|
||||
if isinstance(v, (int, float)):
|
||||
return int(v)
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_validator("last_login_at", mode="before")
|
||||
def _last_login_to_ms(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, datetime.datetime):
|
||||
return int(v.timestamp() * 1000)
|
||||
if isinstance(v, (int, float)):
|
||||
return int(v)
|
||||
return v
|
||||
|
||||
172
api/app/schemas/workspace_schema.py
Normal file
172
api/app/schemas/workspace_schema.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import email
|
||||
from pydantic import BaseModel, Field, EmailStr, field_serializer, computed_field, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from app.models.workspace_model import WorkspaceRole, InviteStatus
|
||||
|
||||
|
||||
class WorkspaceBase(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
icon: str | None = None
|
||||
iconType: str | None = None
|
||||
storage_type: str | None = None
|
||||
llm: str | None = None
|
||||
embedding: str | None = None
|
||||
rerank: str | None = None
|
||||
|
||||
|
||||
class WorkspaceCreate(WorkspaceBase):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class WorkspaceUpdate(BaseModel):
|
||||
name: str | None = Field(None)
|
||||
description: str | None = Field(None)
|
||||
icon: str | None = Field(None)
|
||||
iconType: str | None = Field(None)
|
||||
storage_type: str | None = Field(None)
|
||||
llm: str | None = Field(None)
|
||||
embedding: str | None = Field(None)
|
||||
rerank: str | None = Field(None)
|
||||
|
||||
|
||||
class Workspace(WorkspaceBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
tenant_id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class WorkspaceResponse(WorkspaceBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
tenant_id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
is_active: bool
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp()) if dt else None
|
||||
|
||||
|
||||
class WorkspaceMemberBase(BaseModel):
|
||||
user_id: uuid.UUID
|
||||
role: WorkspaceRole
|
||||
|
||||
|
||||
class WorkspaceMemberCreate(WorkspaceMemberBase):
|
||||
pass
|
||||
|
||||
class WorkspaceMemberUpdate(BaseModel):
|
||||
id: uuid.UUID
|
||||
role: WorkspaceRole
|
||||
|
||||
class WorkspaceMember(WorkspaceMemberBase):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
email: str
|
||||
|
||||
|
||||
# 简版嵌套模型用于成员详情的关系序列化
|
||||
class UserShort(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class WorkspaceShort(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspaceMemberDetail(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
role: WorkspaceRole
|
||||
is_active: bool
|
||||
user: UserShort
|
||||
workspace: WorkspaceShort
|
||||
|
||||
|
||||
# 成员管理表格视图项(扁平化字段,便于前端表格渲染)
|
||||
class WorkspaceMemberItem(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
username: str
|
||||
account: EmailStr
|
||||
role: WorkspaceRole # 原始角色值:manager | member
|
||||
last_login_at: datetime.datetime | None = None
|
||||
|
||||
# 将最后登录时间序列化为毫秒时间戳,便于前端统一格式化
|
||||
@field_serializer("last_login_at", when_used="json")
|
||||
def _serialize_last_login(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
# # 动态计算角色中文标签
|
||||
# @computed_field
|
||||
# def role_label(self) -> str:
|
||||
# return "管理员" if self.role == WorkspaceRole.manager else "成员"
|
||||
|
||||
|
||||
# Workspace Invite Schemas
|
||||
class WorkspaceInviteCreate(BaseModel):
|
||||
email: EmailStr = Field(..., description="被邀请者邮箱")
|
||||
role: WorkspaceRole = Field(..., description="邀请角色:manager 或 member")
|
||||
expires_in_days: int = Field(default=7, ge=1, le=30, description="邀请有效期天数,默认7天")
|
||||
|
||||
|
||||
class WorkspaceInviteResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
email: str
|
||||
role: WorkspaceRole
|
||||
status: InviteStatus
|
||||
expires_at: datetime.datetime
|
||||
accepted_at: datetime.datetime | None
|
||||
created_by_user_id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
invite_token: str | None = Field(None, description="邀请令牌,仅在创建时返回")
|
||||
|
||||
@field_serializer("expires_at", when_used="json")
|
||||
def _serialize_expires_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("accepted_at", when_used="json")
|
||||
def _serialize_accepted_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class InviteValidateResponse(BaseModel):
|
||||
workspace_name: str
|
||||
workspace_id: uuid.UUID
|
||||
email: str
|
||||
role: WorkspaceRole
|
||||
is_expired: bool
|
||||
is_valid: bool
|
||||
|
||||
|
||||
class InviteAcceptRequest(BaseModel):
|
||||
token: str = Field(..., description="邀请令牌")
|
||||
Reference in New Issue
Block a user