Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
32
api/app/schemas/emotion_schema.py
Normal file
32
api/app/schemas/emotion_schema.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""情绪分析相关的请求和响应模型"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EmotionTagsRequest(BaseModel):
|
||||
"""获取情绪标签统计请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
|
||||
start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)")
|
||||
end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)")
|
||||
limit: int = Field(10, ge=1, le=100, description="返回数量限制")
|
||||
|
||||
|
||||
class EmotionWordcloudRequest(BaseModel):
|
||||
"""获取情绪词云数据请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
|
||||
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
|
||||
|
||||
|
||||
class EmotionHealthRequest(BaseModel):
|
||||
"""获取情绪健康指数请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
time_range: str = Field("30d", description="时间范围(7d/30d/90d)")
|
||||
|
||||
|
||||
class EmotionSuggestionsRequest(BaseModel):
|
||||
"""获取个性化情绪建议请求"""
|
||||
group_id: str = Field(..., description="组ID")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)")
|
||||
@@ -13,5 +13,6 @@ class EndUser(BaseModel):
|
||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||
other_address: Optional[str] = Field(description="其他地址", default="")
|
||||
reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now)
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)
|
||||
|
||||
52
api/app/schemas/memory_reflection_schemas.py
Normal file
52
api/app/schemas/memory_reflection_schemas.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OptimizationStrategy(str, Enum):
|
||||
"""优化策略枚举"""
|
||||
SPEED_FIRST = "speed_first"
|
||||
ACCURACY_FIRST = "accuracy_first"
|
||||
BALANCED = "balanced"
|
||||
class Memory_Reflection(BaseModel):
|
||||
config_id: Optional[int] = None
|
||||
reflection_enabled: bool
|
||||
reflection_period_in_hours: str
|
||||
reflexion_range: str
|
||||
baseline: str
|
||||
reflection_model_id: str
|
||||
memory_verify: bool
|
||||
quality_assessment: bool
|
||||
|
||||
# 新增快速引擎优化参数
|
||||
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
|
||||
use_fast_model: Optional[bool] = True
|
||||
enable_caching: Optional[bool] = True
|
||||
enable_streaming: Optional[bool] = True
|
||||
batch_size: Optional[int] = Field(default=3, ge=1, le=10)
|
||||
max_concurrent: Optional[int] = Field(default=5, ge=1, le=20)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class FastReflectionRequest(BaseModel):
|
||||
"""快速反思请求模型"""
|
||||
reflection: Memory_Reflection
|
||||
host_id: Optional[str] = "88a459f5_text02"
|
||||
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ReflectionBenchmarkRequest(BaseModel):
|
||||
"""反思基准测试请求模型"""
|
||||
reflection: Memory_Reflection
|
||||
host_id: Optional[str] = "88a459f5_text02"
|
||||
iterations: Optional[int] = Field(default=3, ge=1, le=10)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
所有的内容是放错误地方了,应该放在models
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
from typing import Any, Optional, List, Dict, Literal, Union
|
||||
import time
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
@@ -28,25 +28,48 @@ class Write_UserInput(BaseModel):
|
||||
# ============================================================================
|
||||
class BaseDataSchema(BaseModel):
|
||||
"""Base schema for the data"""
|
||||
id: str = Field(..., description="The unique identifier for the data entry.")
|
||||
statement: str = Field(..., description="The statement text.")
|
||||
group_id: str = Field(..., description="The group identifier.")
|
||||
chunk_id: str = Field(..., description="The chunk identifier.")
|
||||
# 保持原有必需字段为可选,以兼容不同数据源
|
||||
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
||||
statement: Optional[str] = Field(None, description="The statement text.")
|
||||
group_id: Optional[str] = Field(None, description="The group identifier.")
|
||||
chunk_id: Optional[str] = Field(None, description="The chunk identifier.")
|
||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
|
||||
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
|
||||
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
|
||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||
|
||||
# 新增字段以匹配实际输入数据
|
||||
entity1_name: str = Field(..., description="The first entity name.")
|
||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||
statement_id: str = Field(..., description="The statement identifier.")
|
||||
relationship_type: str = Field(..., description="The relationship type.")
|
||||
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
|
||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||
|
||||
|
||||
class QualityAssessmentSchema(BaseModel):
|
||||
"""Schema for memory quality assessment results."""
|
||||
score: int = Field(..., ge=0, le=100, description="Quality score percentage (0-100).")
|
||||
summary: str = Field(..., description="Brief summary of data quality status, including main issues and strengths.")
|
||||
|
||||
|
||||
class MemoryVerifySchema(BaseModel):
|
||||
"""Schema for memory privacy verification results."""
|
||||
has_privacy: bool = Field(..., description="Whether privacy information was detected.")
|
||||
privacy_types: List[str] = Field([], description="List of detected privacy information types.")
|
||||
summary: str = Field(..., description="Brief summary of privacy detection results.")
|
||||
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -61,7 +84,6 @@ class ConflictSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel):
|
||||
solution: str = Field(..., description="The solution for the reflexion.")
|
||||
|
||||
|
||||
class ChangeRecordSchema(BaseModel):
|
||||
"""Schema for individual change records"""
|
||||
field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.")
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
original_memory_id: Optional[str] = Field(None, description="The original memory identifier.")
|
||||
resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data.")
|
||||
# resolved_memory: Optional[BaseDataSchema] = Field(None, description="The resolved memory data (only contains records that need modification).")
|
||||
resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(None, description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.")
|
||||
change: Optional[List[ChangeRecordSchema]] = Field(None, description="List of detailed change records with IDs and field information.")
|
||||
|
||||
|
||||
class SingleReflexionResultSchema(BaseModel):
|
||||
"""Schema for a single reflexion result item."""
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data for this specific conflict type.")
|
||||
reflexion: ReflexionSchema = Field(..., description="The reflexion data for this conflict.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data for this conflict.")
|
||||
type: str = Field("reflexion_result", description="The type identifier.")
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the reflexion result data in the reflexion_data.json file."""
|
||||
# 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data.")
|
||||
reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
|
||||
99
api/app/schemas/prompt_optimizer_schema.py
Normal file
99
api/app/schemas/prompt_optimizer_schema.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
# =========================================
|
||||
# API Request Schemas
|
||||
# =========================================
|
||||
class PromptOptMessage(BaseModel):
|
||||
model_id: UUID = Field(
|
||||
...,
|
||||
description="Model ID"
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="User's input message"
|
||||
)
|
||||
|
||||
current_prompt: str = Field(
|
||||
default="",
|
||||
description="currently optimized prompt"
|
||||
)
|
||||
|
||||
|
||||
class PromptOptModelSet(BaseModel):
|
||||
id: UUID | None = Field(
|
||||
default=None,
|
||||
description="Configuration ID"
|
||||
)
|
||||
|
||||
system_prompt: str = Field(
|
||||
...,
|
||||
description="System Prompt"
|
||||
)
|
||||
|
||||
|
||||
# =========================================
|
||||
# Service Layer Results
|
||||
# =========================================
|
||||
class OptimizePromptResult(BaseModel):
|
||||
prompt: str = Field(
|
||||
...,
|
||||
description="Optimized Prompt"
|
||||
)
|
||||
desc: str = Field(
|
||||
...,
|
||||
description="Description"
|
||||
)
|
||||
|
||||
|
||||
# =========================================
|
||||
# API Response Schemas
|
||||
# =========================================
|
||||
class CreateSessionResponse(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
id: UUID = Field(
|
||||
...,
|
||||
description="Session ID"
|
||||
)
|
||||
|
||||
|
||||
class OptimizePromptResponse(BaseModel):
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
description="Optimized Prompt"
|
||||
)
|
||||
desc: str = Field(
|
||||
...,
|
||||
description="Description"
|
||||
)
|
||||
variables: list = Field(
|
||||
...,
|
||||
description="Variables"
|
||||
)
|
||||
|
||||
|
||||
class SessionMessage(BaseModel):
|
||||
role: str = Field(
|
||||
...,
|
||||
description="Message role (user/assistant)"
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="Message content"
|
||||
)
|
||||
|
||||
|
||||
class SessionHistoryResponse(BaseModel):
|
||||
session_id: UUID = Field(
|
||||
...,
|
||||
description="Session ID"
|
||||
)
|
||||
messages: list[SessionMessage] = Field(
|
||||
...,
|
||||
description="List of messages in the session"
|
||||
)
|
||||
Reference in New Issue
Block a user