Merge branch 'develop' into fix/memory-enduser-config

This commit is contained in:
Ke Sun
2026-02-06 16:25:57 +08:00
69 changed files with 38144 additions and 362 deletions

28618
api/General_purpose_entity.ttl Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -290,7 +290,8 @@ async def pilot_run(
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, language={language}"
f"dialogue_text_length={len(payload.dialogue_text)}, "
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
)
payload.config_id = resolve_config_id(payload.config_id, db)
svc = DataConfigService(db)

View File

@@ -4,13 +4,14 @@
Endpoints:
POST /api/memory/ontology/extract - 提取本体类
POST /api/memory/ontology/export - 导出OWL文件
POST /api/memory/ontology/export - 按场景导出OWL文件
POST /api/memory/ontology/import - 导入OWL文件到指定场景
POST /api/memory/ontology/scene - 创建本体场景
PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景
DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景
GET /api/memory/ontology/scene/{scene_id} - 获取单个场景
GET /api/memory/ontology/scenes - 获取场景列表
POST /api/memory/ontology/class - 创建本体类型
POST /api/memory/ontology/class - 创建本体类型(支持批量)
PUT /api/memory/ontology/class/{class_id} - 更新本体类型
DELETE /api/memory/ontology/class/{class_id} - 删除本体类型
GET /api/memory/ontology/class/{class_id} - 获取单个类型
@@ -19,11 +20,15 @@ Endpoints:
import logging
import tempfile
from typing import Dict, Optional
import io
from typing import Dict, Optional, List
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, Header
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
@@ -31,11 +36,10 @@ from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.core.memory.models.ontology_models import OntologyClass
from typing import List
from app.core.memory.models.ontology_scenario_models import OntologyClass
from app.schemas.ontology_schemas import (
ExportRequest,
ExportResponse,
ExportBySceneRequest,
ExportBySceneResponse,
ExtractionRequest,
ExtractionResponse,
SceneCreateRequest,
@@ -46,6 +50,7 @@ from app.schemas.ontology_schemas import (
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
ImportOwlResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
@@ -187,22 +192,19 @@ async def extract_ontology(
从场景描述中提取符合OWL规范的本体类。
提取结果仅返回给前端,不会自动保存到数据库。
前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。
支持中英文切换,通过 X-Language-Type Header 指定语言。
Args:
request: 提取请求,包含scenario、domain、llm_id和scene_id
language_type: 语言类型'zh'(中文)或 'en'(英文),默认 'zh'
language_type: 语言类型 Header (zh/en)
db: 数据库会话
current_user: 当前用户
"""
api_logger.info(
f"Ontology extraction requested by user {current_user.id}, "
f"scenario_length={len(request.scenario)}, "
f"domain={request.domain}, "
f"llm_id={request.llm_id}, "
f"scene_id={request.scene_id}, "
f"language_type={language_type}"
f"scene_id={request.scene_id}"
)
try:
@@ -222,7 +224,7 @@ async def extract_ontology(
llm_id=request.llm_id
)
# 调用服务层执行提取传入scene_id和workspace_id
# 调用服务层执行提取
result = await service.extract_ontology(
scenario=request.scenario,
domain=request.domain,
@@ -231,7 +233,7 @@ async def extract_ontology(
language=language
)
# 构建响应(语言已在提取时通过模板控制,无需二次翻译)
# 构建响应
response = ExtractionResponse(
classes=result.classes,
domain=result.domain,
@@ -240,7 +242,7 @@ async def extract_ontology(
api_logger.info(
f"Ontology extraction completed, extracted {len(result.classes)} classes, "
f"saved to scene {request.scene_id}, language={language_type}"
f"scene_id={request.scene_id}, language={language}"
)
return success(data=response.model_dump(), msg="本体提取成功")
@@ -261,146 +263,6 @@ async def extract_ontology(
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
@router.post("/export", response_model=ApiResponse)
async def export_owl(
request: ExportRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""导出OWL文件
将提取的本体类导出为OWL文件,支持多种格式。
导出操作不需要LLM,只使用OWL验证器和Owlready2库。
Args:
request: 导出请求,包含classes、format和include_metadata
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含OWL文件内容的响应
Supported formats:
- rdfxml: 标准OWL RDF/XML格式(完整)
- turtle: Turtle格式(可读性好)
- ntriples: N-Triples格式(简单)
- json: JSON格式(简化,只包含类信息)
Response format:
{
"code": 200,
"msg": "OWL文件导出成功",
"data": {
"owl_content": "...",
"format": "rdfxml",
"classes_count": 7
}
}
"""
api_logger.info(
f"OWL export requested by user {current_user.id}, "
f"classes_count={len(request.classes)}, "
f"format={request.format}, "
f"include_metadata={request.include_metadata}"
)
try:
# 验证格式
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
if request.format not in valid_formats:
api_logger.warning(f"Invalid export format: {request.format}")
return fail(
BizCode.BAD_REQUEST,
"不支持的导出格式",
f"format必须是以下之一: {', '.join(valid_formats)}"
)
# JSON格式直接导出,不需要OWL验证
if request.format == "json":
owl_validator = OWLValidator()
owl_content = owl_validator.export_to_owl(
world=None,
format="json",
classes=request.classes
)
response = ExportResponse(
owl_content=owl_content,
format=request.format,
classes_count=len(request.classes)
)
api_logger.info(
f"JSON export completed, content_length={len(owl_content)}"
)
return success(data=response.model_dump(), msg="OWL文件导出成功")
# 创建临时文件路径
with tempfile.NamedTemporaryFile(
mode='w',
suffix='.owl',
delete=False
) as tmp_file:
output_path = tmp_file.name
# 导出操作不需要LLM,直接使用OWL验证器
owl_validator = OWLValidator()
# 验证本体类
logger.debug("Validating ontology classes")
is_valid, errors, world = owl_validator.validate_ontology_classes(
classes=request.classes,
)
if not is_valid:
logger.warning(
f"OWL validation found {len(errors)} issues during export: {errors}"
)
# 继续导出,但记录警告
if not world:
error_msg = "Failed to create OWL world for export"
logger.error(error_msg)
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
# 导出OWL文件
logger.info(f"Exporting to {request.format} format")
owl_content = owl_validator.export_to_owl(
world=world,
output_path=output_path,
format=request.format,
classes=request.classes
)
# 构建响应
response = ExportResponse(
owl_content=owl_content,
format=request.format,
classes_count=len(request.classes)
)
api_logger.info(
f"OWL export completed, format={request.format}, "
f"content_length={len(owl_content)}"
)
return success(data=response.model_dump(), msg="OWL文件导出成功")
except ValueError as e:
# 验证错误 (400)
api_logger.warning(f"Validation error in export: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
# 运行时错误 (500)
api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
except Exception as e:
# 未知错误 (500)
api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
# ==================== 本体场景管理接口 ====================
@@ -893,3 +755,370 @@ async def get_class(
"""
from app.controllers.ontology_secondary_routes import get_class_handler
return await get_class_handler(class_id, db, current_user)
# ==================== OWL 导入接口 ====================
@router.post("/import", response_model=ApiResponse)
async def import_owl_file(
    scene_name: str = Form(..., description="场景名称"),
    scene_description: Optional[str] = Form(None, description="场景描述(可选)"),
    file: UploadFile = File(..., description="OWL/TTL文件"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import an OWL/TTL file and create a new scene holding its classes.

    The uploaded file is parsed for ontology classes first; only when parsing
    yields at least one class is a new scene created and the classes stored
    under it. Scene and classes are committed together at the end, and any
    exception rolls the session back.

    The parser format is chosen from the file extension:
    - .owl, .rdf, .xml -> "rdfxml"
    - .ttl -> "turtle"

    Args:
        scene_name: Name of the scene to create (form field). Must be
            non-blank and not already used in the current workspace.
        scene_description: Optional scene description (form field).
        file: Uploaded file; supported extensions are .owl, .ttl, .rdf, .xml.
            The content must be UTF-8 encoded.
        db: Database session.
        current_user: Authenticated user; their current workspace owns the
            new scene.

    Returns:
        The result of ``success(...)`` carrying an ``ImportOwlResponse`` dump
        on success, otherwise the result of ``fail(...)`` describing the
        problem (unsupported extension, missing workspace, blank or duplicate
        scene name, bad encoding, no classes found, or an unexpected error).
    """
    from app.repositories.ontology_scene_repository import OntologySceneRepository
    from app.repositories.ontology_class_repository import OntologyClassRepository

    # Choose the parser format from the (lower-cased) file extension.
    filename = file.filename.lower() if file.filename else ""
    if filename.endswith('.ttl'):
        owl_format = "turtle"
        file_type = "ttl"
    elif filename.endswith(('.owl', '.rdf', '.xml')):
        owl_format = "rdfxml"
        file_type = "owl"
    else:
        return fail(
            BizCode.BAD_REQUEST,
            "文件格式不支持",
            f"不支持的文件格式: {filename},支持的格式: .owl, .ttl, .rdf, .xml"
        )

    api_logger.info(
        f"OWL import requested by user {current_user.id}, "
        f"scene_name={scene_name}, "
        f"filename={file.filename}, "
        f"format={owl_format}"
    )

    try:
        # Resolve the caller's current workspace; the import is scoped to it.
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        # 1. Reject blank scene names.
        if not scene_name or not scene_name.strip():
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "场景名称不能为空")

        scene_name = scene_name.strip()

        # 2. Reject a scene name that already exists in this workspace.
        scene_repo = OntologySceneRepository(db)
        existing_scene = scene_repo.get_by_name(scene_name, workspace_id)
        if existing_scene:
            api_logger.warning(f"Scene name already exists: {scene_name}")
            return fail(
                BizCode.BAD_REQUEST,
                "场景名称已存在",
                f"工作空间下已存在名为 '{scene_name}' 的场景"
            )

        # 3. Read the whole upload into memory and decode it as UTF-8.
        try:
            content = await file.read()
            owl_content = content.decode('utf-8')
        except UnicodeDecodeError:
            return fail(
                BizCode.BAD_REQUEST,
                f"{file_type}文件导入失败",
                "文件编码错误,请确保文件使用 UTF-8 编码"
            )

        # 4. Parse before touching the database, so a bad file never leaves
        #    an empty scene behind.
        owl_validator = OWLValidator()
        parsed_classes = owl_validator.parse_owl_content(
            owl_content=owl_content,
            format=owl_format
        )

        if not parsed_classes:
            api_logger.warning("No classes found in OWL content")
            return fail(
                BizCode.BAD_REQUEST,
                "未找到本体类型",
                "文件中没有定义任何本体类型owl:Class"
            )

        # 5. Parsing succeeded: create the scene.
        scene = scene_repo.create(
            scene_data={
                "scene_name": scene_name,
                "scene_description": scene_description
            },
            workspace_id=workspace_id
        )
        scene_uuid = scene.scene_id
        api_logger.info(f"Scene created for import: {scene_uuid}")

        # 6. Create the classes, skipping names repeated within this file.
        class_repo = OntologyClassRepository(db)
        created_items = []
        existing_names = set()
        skipped_count = 0

        for cls in parsed_classes:
            class_name = cls["name"]
            class_description = cls.get("description")

            # Duplicate within the same batch: count it and move on.
            if class_name in existing_names:
                skipped_count += 1
                api_logger.debug(f"Skipping duplicate class in batch: {class_name}")
                continue

            # Create the class under the new scene.
            ontology_class = class_repo.create(
                class_data={
                    "class_name": class_name,
                    "class_description": class_description
                },
                scene_id=scene_uuid
            )

            # Remember the name so later repeats in this batch are skipped.
            existing_names.add(class_name)

            created_items.append(ClassResponse(
                class_id=ontology_class.class_id,
                class_name=ontology_class.class_name,
                class_description=ontology_class.class_description,
                scene_id=ontology_class.scene_id,
                created_at=ontology_class.created_at,
                updated_at=ontology_class.updated_at
            ))

        # 7. Commit the scene and all classes together.
        db.commit()

        # 8. Build the response payload.
        response = ImportOwlResponse(
            scene_id=scene_uuid,
            scene_name=scene.scene_name,
            imported_count=len(created_items),
            skipped_count=skipped_count,
            items=created_items
        )

        api_logger.info(
            f"{file_type} import completed, "
            f"scene_id={scene_uuid}, "
            f"scene_name={scene_name}, "
            f"format={owl_format}, "
            f"imported={len(created_items)}, "
            f"skipped={skipped_count}"
        )

        return success(data=response.model_dump(), msg=f"{file_type}文件导入成功")

    except ValueError as e:
        # Validation-type failure: undo any uncommitted scene/class rows.
        db.rollback()
        api_logger.warning(f"Validation error in import: {str(e)}")
        return fail(BizCode.BAD_REQUEST, f"{file_type}文件导入失败", str(e))

    except Exception as e:
        # Anything else is reported as an internal error after rolling back.
        db.rollback()
        api_logger.error(f"Unexpected error in import: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导入失败", str(e))
# ==================== OWL 导出接口 ====================
@router.post("/export")
async def export_owl_by_scene(
    request: ExportBySceneRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Export all ontology classes of one scene as a downloadable OWL/TTL file.

    The scene identified by ``request.scene_id`` is loaded, checked against
    the caller's current workspace, and its classes are converted to
    ``OntologyClass`` objects and serialised in the requested format.

    Args:
        request: Export request carrying ``scene_id`` and ``format``
            ("rdfxml" or "turtle"; an empty format falls back to "rdfxml").
        db: Database session.
        current_user: Authenticated user.

    Returns:
        A ``StreamingResponse`` with the serialised ontology as an attachment
        on success, otherwise the result of ``fail(...)`` (invalid format,
        missing workspace, unknown scene, scene in another workspace, empty
        scene, or an export error).
    """
    import os

    from app.repositories.ontology_scene_repository import OntologySceneRepository
    from app.repositories.ontology_class_repository import OntologyClassRepository

    api_logger.info(
        f"OWL export by scene requested by user {current_user.id}, "
        f"scene_id={request.scene_id}, "
        f"format={request.format}"
    )

    try:
        # Validate the requested serialisation format.
        valid_formats = ["rdfxml", "turtle"]
        owl_format = request.format.lower() if request.format else "rdfxml"

        if owl_format not in valid_formats:
            api_logger.warning(f"Invalid format: {request.format}")
            return fail(
                BizCode.BAD_REQUEST,
                "格式参数无效",
                f"不支持的格式: {request.format},支持的格式: rdfxml, turtle"
            )

        # Resolve the caller's current workspace.
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        # 1. Load the scene.
        scene_repo = OntologySceneRepository(db)
        scene = scene_repo.get_by_id(request.scene_id)

        if not scene:
            api_logger.warning(f"Scene not found: {request.scene_id}")
            return fail(BizCode.NOT_FOUND, "场景不存在", f"找不到场景: {request.scene_id}")

        # The scene must belong to the caller's current workspace.
        if scene.workspace_id != workspace_id:
            api_logger.warning(
                f"Scene {request.scene_id} does not belong to workspace {workspace_id}"
            )
            return fail(BizCode.FORBIDDEN, "无权访问", "该场景不属于当前工作空间")

        # 2. Load every ontology class stored under the scene.
        class_repo = OntologyClassRepository(db)
        ontology_classes_db = class_repo.get_by_scene(request.scene_id)

        if not ontology_classes_db:
            api_logger.warning(f"No classes found in scene: {request.scene_id}")
            return fail(BizCode.BAD_REQUEST, "场景为空", "该场景下没有定义任何本体类型")

        # 3. Convert database rows into the OntologyClass shape the exporter expects.
        ontology_classes: List[OntologyClass] = []
        for db_class in ontology_classes_db:
            owl_class = OntologyClass(
                id=str(db_class.class_id),
                name=db_class.class_name,
                name_chinese=db_class.class_name if _is_chinese(db_class.class_name) else None,
                description=db_class.class_description or "",
                examples=[],
                parent_class=None,
                entity_type="Concept",
                domain=scene.scene_name
            )
            ontology_classes.append(owl_class)

        # 4. Derive the download file name, extension and MIME type.
        file_ext = ".ttl" if owl_format == "turtle" else ".owl"
        filename = _sanitize_filename(scene.scene_name) + file_ext
        media_type = "text/turtle" if owl_format == "turtle" else "application/rdf+xml"
        file_type = "ttl" if owl_format == "turtle" else "owl"

        # 5. Export through a scratch file. It is created with delete=False so
        #    the exporter can reopen it by path, which means this function owns
        #    its removal: the finally block deletes it on every exit path.
        output_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode='w',
                suffix='.owl',
                delete=False
            ) as tmp_file:
                output_path = tmp_file.name

            owl_validator = OWLValidator()

            # Validate the classes; validation issues are logged but do not
            # stop the export.
            is_valid, errors, world = owl_validator.validate_ontology_classes(
                classes=ontology_classes,
            )

            if not is_valid:
                logger.warning(
                    f"OWL validation found {len(errors)} issues during export: {errors}"
                )

            if not world:
                error_msg = "Failed to create OWL world for export"
                logger.error(error_msg)
                return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)

            # Serialise in the format the request asked for.
            owl_content = owl_validator.export_to_owl(
                world=world,
                output_path=output_path,
                format=owl_format,
                classes=ontology_classes
            )
        finally:
            if output_path is not None:
                try:
                    os.unlink(output_path)
                except OSError as cleanup_error:
                    # Cleanup is best effort; never mask the export outcome.
                    api_logger.warning(
                        f"Could not remove temporary export file {output_path}: {cleanup_error}"
                    )

        api_logger.info(
            f"{file_type} export by scene completed, "
            f"scene={scene.scene_name}, "
            f"filename={filename}, "
            f"format={owl_format}, "
            f"classes_count={len(ontology_classes)}"
        )

        # 6. Stream the content back as an attachment.
        # `filename` carries an ASCII-safe default name; `filename*` carries the
        # real, percent-encoded UTF-8 name for clients that support it.
        ascii_filename = f"ontology{file_ext}"
        encoded_filename = quote(filename)

        return StreamingResponse(
            io.BytesIO(owl_content.encode('utf-8')),
            media_type=media_type,
            headers={
                "Content-Disposition": f"attachment; filename=\"{ascii_filename}\"; filename*=UTF-8''{encoded_filename}"
            }
        )

    except ValueError as e:
        api_logger.warning(f"Validation error in export by scene: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        api_logger.error(f"Runtime error in export by scene: {str(e)}", exc_info=True)
        file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
        return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))

    except Exception as e:
        api_logger.error(f"Unexpected error in export by scene: {str(e)}", exc_info=True)
        file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
        return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))
def _is_chinese(text: str) -> bool:
"""检查文本是否包含中文字符"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return True
return False
def _sanitize_filename(name: str) -> str:
"""清理文件名,移除不合法字符"""
import re
# 移除或替换不合法的文件名字符
sanitized = re.sub(r'[<>:"/\\|?*]', '_', name)
# 移除前后空格
sanitized = sanitized.strip()
# 如果为空,使用默认名称
if not sanitized:
sanitized = "ontology_export"
return sanitized

View File

@@ -5,7 +5,7 @@ from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.dependencies import get_current_user
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
@@ -16,7 +16,6 @@ router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
@cur_workspace_access_guard()
def create_skill(
data: skill_schema.SkillCreate,
db: Session = Depends(get_db),
@@ -29,7 +28,6 @@ def create_skill(
@router.get("", summary="技能列表")
@cur_workspace_access_guard()
def list_skills(
search: Optional[str] = Query(None, description="搜索关键词"),
is_active: Optional[bool] = Query(None, description="是否激活"),
@@ -51,7 +49,6 @@ def list_skills(
@router.get("/{skill_id}", summary="获取技能详情")
@cur_workspace_access_guard()
def get_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
@@ -64,7 +61,6 @@ def get_skill(
@router.put("/{skill_id}", summary="更新技能")
@cur_workspace_access_guard()
def update_skill(
skill_id: uuid.UUID,
data: skill_schema.SkillUpdate,
@@ -78,7 +74,6 @@ def update_skill(
@router.delete("/{skill_id}", summary="删除技能")
@cur_workspace_access_guard()
def delete_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),

View File

@@ -221,6 +221,28 @@ class Settings:
# workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
# ========================================================================
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
# Prompt 中最大类型数量
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
# 核心通用类型列表(逗号分隔)
CORE_GENERAL_TYPES: str = os.getenv(
"CORE_GENERAL_TYPES",
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
)
# 实验模式开关(允许通过 API 动态切换本体配置)
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.

View File

@@ -94,6 +94,31 @@ async def write(
from app.core.memory.utils.config.config_utils import get_pipeline_config
pipeline_config = get_pipeline_config(memory_config)
# Fetch ontology types if scene_id is configured
ontology_types = None
if memory_config.scene_id:
try:
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
with get_db_context() as db:
ontology_types = load_ontology_types_for_scene(
scene_id=memory_config.scene_id,
workspace_id=memory_config.workspace_id,
db=db
)
if ontology_types:
logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
)
else:
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
except Exception as e:
logger.warning(
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
exc_info=True
)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,

View File

@@ -58,12 +58,25 @@ from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
# Ontology models
from app.core.memory.models.ontology_models import (
# Ontology scenario models (LLM extracted from scenarios)
from app.core.memory.models.ontology_scenario_models import (
OntologyClass,
OntologyExtractionResponse,
)
# Ontology extraction models (for extraction flow)
from app.core.memory.models.ontology_extraction_models import (
OntologyTypeInfo,
OntologyTypeList,
)
# Ontology general models (loaded from external ontology files)
from app.core.memory.models.ontology_general_models import (
OntologyFileFormat,
GeneralOntologyType,
GeneralOntologyTypeRegistry,
)
# Variable configuration models
from app.core.memory.models.variate_config import (
StatementExtractionConfig,
@@ -114,6 +127,13 @@ __all__ = [
# Ontology models
"OntologyClass",
"OntologyExtractionResponse",
# Ontology type models for extraction flow
"OntologyTypeInfo",
"OntologyTypeList",
# General ontology type models
"OntologyFileFormat",
"GeneralOntologyType",
"GeneralOntologyTypeRegistry",
# Variable configuration
"StatementExtractionConfig",
"ForgettingEngineConfig",

View File

@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
"""本体类型数据结构模块
本模块定义用于在萃取流程中传递本体类型信息的轻量级数据类。
Classes:
OntologyTypeInfo: 单个本体类型信息
OntologyTypeList: 本体类型列表
"""
from dataclasses import dataclass
from typing import List
@dataclass
class OntologyTypeInfo:
    """A single ontology type as passed through the extraction flow.

    Attributes:
        class_name: Name of the type.
        class_description: Description of the type.
    """

    class_name: str
    class_description: str

    def to_prompt_format(self) -> str:
        """Render this type as one prompt line.

        Returns:
            A string of the form ``"- TypeName: Description"``.
        """
        return "- {}: {}".format(self.class_name, self.class_description)
@dataclass
class OntologyTypeList:
    """A list of ontology types used by the extraction flow.

    Attributes:
        types: The ontology type entries.
    """

    types: List[OntologyTypeInfo]

    @classmethod
    def from_db_models(cls, ontology_classes: list) -> "OntologyTypeList":
        """Build an OntologyTypeList from database model objects.

        Args:
            ontology_classes: Objects exposing ``class_name`` and
                ``class_description`` attributes. A falsy description is
                stored as an empty string.

        Returns:
            A new OntologyTypeList with one entry per input object.
        """
        return cls(
            types=[
                OntologyTypeInfo(
                    class_name=oc.class_name,
                    class_description=oc.class_description or "",
                )
                for oc in ontology_classes
            ]
        )

    def to_prompt_section(self) -> str:
        """Render the type list for inclusion in a prompt.

        Returns:
            One formatted line per type joined by newlines, or an empty
            string when there are no types.
        """
        if not self.types:
            return ""
        return "\n".join(t.to_prompt_format() for t in self.types)

    def get_type_names(self) -> List[str]:
        """Return the names of all types, in list order."""
        return [t.class_name for t in self.types]

    def get_type_hierarchy_hints(self) -> List[str]:
        """Collect hierarchy hints for the types in this list.

        Each type name is passed to ``OntologyTypeMerger.get_type_hierarchy_hint``
        and every non-empty result is kept. This is best effort: any failure
        stops the collection and the hints gathered so far (possibly none)
        are returned. The failure is logged at debug level instead of being
        dropped silently, so a permanently failing lookup can be diagnosed.

        Returns:
            The hint strings that could be obtained.
        """
        hints: List[str] = []
        try:
            from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger

            # NOTE(review): OntologyTypeMerger.__init__ appears to require a
            # `general_registry` argument; if so this call raises TypeError on
            # every invocation and the method always returns []. Confirm and
            # pass the registry (or use the project's merger factory).
            merger = OntologyTypeMerger()
            for type_info in self.types:
                hint = merger.get_type_hierarchy_hint(type_info.class_name)
                if hint:
                    hints.append(hint)
        except Exception:
            # Deliberately non-fatal, but leave a trace of why hints are missing.
            import logging

            logging.getLogger(__name__).debug(
                "Could not collect ontology type hierarchy hints", exc_info=True
            )
        return hints

View File

@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
"""通用本体类型数据模型模块
本模块定义用于通用本体类型管理的数据结构,包括:
- OntologyFileFormat: 本体文件格式枚举
- GeneralOntologyType: 通用本体类型数据类
- GeneralOntologyTypeRegistry: 通用本体类型注册表
Classes:
OntologyFileFormat: 本体文件格式枚举,支持 TTL、OWL/XML、RDF/XML、N-Triples、JSON-LD
GeneralOntologyType: 通用本体类型包含类名、URI、标签、描述、父类等信息
GeneralOntologyTypeRegistry: 类型注册表,管理类型集合和层次结构
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
class OntologyFileFormat(Enum):
    """Serialisation formats of ontology files.

    Members:
        TURTLE: Turtle (.ttl files)
        RDF_XML: RDF/XML (.owl, .rdf files)
        N_TRIPLES: N-Triples (.nt files)
        JSON_LD: JSON-LD (.jsonld, .json files)
    """

    TURTLE = "turtle"    # .ttl files
    RDF_XML = "xml"      # .owl, .rdf (RDF/XML)
    N_TRIPLES = "nt"     # .nt files
    JSON_LD = "json-ld"  # .jsonld files

    @classmethod
    def from_extension(cls, file_path: str) -> "OntologyFileFormat":
        """Infer the format from the extension of *file_path*.

        The text after the last ``.`` (case-insensitive) is looked up; when
        the path contains no dot, the whole path is used as the key.

        Args:
            file_path: Path or file name.

        Returns:
            The matching format, or ``RDF_XML`` for unknown extensions.
        """
        suffix = file_path.lower().rpartition('.')[2]
        if suffix == 'ttl':
            return cls.TURTLE
        if suffix == 'nt':
            return cls.N_TRIPLES
        if suffix in ('jsonld', 'json'):
            return cls.JSON_LD
        # 'owl', 'rdf' and every unrecognised suffix resolve to RDF/XML.
        return cls.RDF_XML
@dataclass
class GeneralOntologyType:
    """A type definition parsed from an ontology file.

    Attributes:
        class_name: Type name, e.g. "Person".
        class_uri: Full URI, e.g. "http://dbpedia.org/ontology/Person".
        labels: Labels keyed by language code (e.g. "en", "zh").
        description: Description of the type.
        parent_class: Name of the parent type, used to build the hierarchy.
        source_file: Path of the file the type came from.
    """

    class_name: str                                        # type name, e.g. "Person"
    class_uri: str                                         # full URI
    labels: Dict[str, str] = field(default_factory=dict)   # labels per language code
    description: Optional[str] = None                      # type description
    parent_class: Optional[str] = None                     # parent type name
    source_file: Optional[str] = None                      # originating file

    def get_label(self, lang: str = "en") -> str:
        """Return the label for *lang* with fallbacks.

        Lookup order: the label for *lang*, then the "en" label, then the
        type name itself.

        Args:
            lang: Language code, "en" by default.

        Returns:
            The best available label.
        """
        if lang in self.labels:
            return self.labels[lang]
        return self.labels.get("en", self.class_name)
@dataclass
class GeneralOntologyTypeRegistry:
    """Registry of parsed ontology types.

    Offers lookup by name, upward and downward hierarchy traversal, merging
    of registries and summary statistics.

    Attributes:
        types: Type definitions keyed by type name.
        hierarchy: Child type names keyed by parent type name.
        source_files: Paths of the files loaded into this registry.
    """

    types: Dict[str, GeneralOntologyType] = field(default_factory=dict)
    hierarchy: Dict[str, Set[str]] = field(default_factory=dict)  # parent -> children
    source_files: List[str] = field(default_factory=list)

    def get_type(self, name: str) -> Optional[GeneralOntologyType]:
        """Look up a type by name.

        Args:
            name: Type name.

        Returns:
            The registered definition, or None when the name is unknown.
        """
        return self.types.get(name)

    def get_ancestors(self, name: str) -> List[str]:
        """List the ancestors of a type, nearest first.

        Follows ``parent_class`` links starting at *name*. A set of visited
        names stops the walk (with a warning) if a link leads back to a type
        that was already seen.

        Args:
            name: Type name.

        Returns:
            Ancestor names ordered from the direct parent outwards.
        """
        chain: List[str] = []
        seen = set()
        node = name

        while node and node not in seen:
            seen.add(node)
            entry = self.types.get(node)
            parent = entry.parent_class if entry else None
            if not parent:
                # Unknown type or a root: nothing further up.
                break
            if parent in seen:
                # The parent link closes a cycle; stop instead of looping.
                logger.warning(
                    f"检测到类型层次循环引用: {node} -> {parent}"
                    f"已遍历路径: {' -> '.join([name] + chain)}"
                )
                break
            chain.append(parent)
            node = parent

        return chain

    def get_descendants(self, name: str) -> Set[str]:
        """Collect all descendants of a type.

        Walks the ``hierarchy`` mapping downwards from *name* with a work
        list; children already collected are not queued again.

        Args:
            name: Type name.

        Returns:
            The set of descendant type names.
        """
        found: Set[str] = set()
        pending = [name]

        while pending:
            node = pending.pop()
            fresh = self.hierarchy.get(node, set()) - found
            found.update(fresh)
            pending.extend(fresh)

        return found

    def merge(self, other: "GeneralOntologyTypeRegistry") -> None:
        """Merge *other* into this registry (existing entries win).

        Types whose name is already registered here are kept unchanged; new
        names are added. Child sets in ``hierarchy`` are unioned per parent,
        and the other registry's source files are appended.

        Args:
            other: Registry to merge in.
        """
        for type_name, definition in other.types.items():
            self.types.setdefault(type_name, definition)

        for parent, children in other.hierarchy.items():
            self.hierarchy.setdefault(parent, set()).update(children)

        self.source_files.extend(other.source_files)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the registry.

        Returns:
            A dict with the keys:
            - total_types: number of registered types
            - root_types: number of types without a parent
            - max_depth: longest ancestor chain of any type
            - source_files: the list of source files
        """
        root_count = sum(1 for t in self.types.values() if not t.parent_class)
        return {
            "total_types": len(self.types),
            "root_types": root_count,
            "max_depth": self._calculate_max_depth(),
            "source_files": self.source_files,
        }

    def _calculate_max_depth(self) -> int:
        """Return the length of the longest ancestor chain over all types.

        Returns:
            The maximum depth, 0 for an empty registry.
        """
        return max(
            (len(self.get_ancestors(type_name)) for type_name in self.types),
            default=0,
        )

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""本体类型服务模块
本模块提供本体类型相关的服务,包括:
- OntologyTypeMerger: 本体类型合并服务
- get_general_ontology_registry: 获取通用本体类型注册表(单例,懒加载)
- get_ontology_type_merger: 获取类型合并服务实例
- reload_ontology_registry: 重新加载本体注册表(实验模式)
- clear_ontology_cache: 清除本体缓存
- is_general_ontology_enabled: 检查通用本体类型功能是否启用
"""
from .ontology_type_merger import OntologyTypeMerger, DEFAULT_CORE_GENERAL_TYPES
from .ontology_type_loader import (
get_general_ontology_registry,
get_ontology_type_merger,
reload_ontology_registry,
clear_ontology_cache,
is_general_ontology_enabled,
)
__all__ = [
"OntologyTypeMerger",
"DEFAULT_CORE_GENERAL_TYPES",
"get_general_ontology_registry",
"get_ontology_type_merger",
"reload_ontology_registry",
"clear_ontology_cache",
"is_general_ontology_enabled",
]

View File

@@ -0,0 +1,145 @@
"""本体类型加载器
提供统一的本体类型加载逻辑,避免代码重复。
Functions:
load_ontology_types_for_scene: 从数据库加载场景的本体类型
is_general_ontology_enabled: 检查是否启用通用本体
"""
import logging
from typing import Optional
from uuid import UUID
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
def load_ontology_types_for_scene(
    scene_id: Optional[UUID],
    workspace_id: UUID,
    db: Session
) -> Optional["OntologyTypeList"]:
    """Load the ontology types stored for a scene.

    Shared loader intended to replace per-caller copies of this logic. Any
    exception raised while loading is logged and turned into a None result.

    Args:
        scene_id: Scene identifier; a falsy value short-circuits to None.
        workspace_id: Workspace identifier passed to the repository query.
        db: Database session.

    Returns:
        An OntologyTypeList when the scene has classes, otherwise None
        (no scene id, no classes, or a loading error).

    Examples:
        >>> ontology_types = load_ontology_types_for_scene(
        ...     scene_id=scene_uuid,
        ...     workspace_id=workspace_uuid,
        ...     db=db_session
        ... )
        >>> if ontology_types:
        ...     print(f"Loaded {len(ontology_types.types)} types")
    """
    if not scene_id:
        return None

    try:
        from app.core.memory.models.ontology_extraction_models import OntologyTypeList
        from app.repositories.ontology_class_repository import OntologyClassRepository

        # Fetch the scene's classes from the repository.
        rows = OntologyClassRepository(db).get_classes_by_scene(
            scene_id=scene_id,
            workspace_id=workspace_id
        )

        if rows:
            # Convert the database models into the lightweight list type.
            loaded = OntologyTypeList.from_db_models(rows)
            logger.info(
                f"Loaded {len(loaded.types)} ontology types for scene_id: {scene_id}"
            )
            return loaded

        logger.info(f"No ontology types found for scene_id: {scene_id}")
        return None

    except Exception as exc:
        logger.error(f"Failed to load ontology types for scene_id {scene_id}: {exc}", exc_info=True)
        return None
def create_empty_ontology_type_list() -> Optional["OntologyTypeList"]:
    """Create an empty type list for flows that rely on general types only.

    Returns:
        An OntologyTypeList with no entries when
        ``is_general_ontology_enabled()`` reports True; None when it reports
        False or when creating the list fails.
    """
    try:
        from app.core.memory.models.ontology_extraction_models import OntologyTypeList

        if not is_general_ontology_enabled():
            return None

        logger.info("Creating empty OntologyTypeList for general types only")
        return OntologyTypeList(types=[])

    except Exception as exc:
        logger.warning(f"Failed to create empty OntologyTypeList: {exc}")
        return None
def is_general_ontology_enabled() -> bool:
    """Report whether the general ontology is available.

    Builds an OntologyTypeMerger and checks that its ``general_registry`` is
    set. Any exception is logged as a warning and reported as False.

    Returns:
        True when a merger with a non-None registry could be built,
        otherwise False.
    """
    try:
        from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger

        # NOTE(review): OntologyTypeMerger.__init__ appears to require a
        # `general_registry` argument; if so this no-argument call raises
        # TypeError and the function always returns False. Confirm how the
        # registry is meant to be supplied here.
        return OntologyTypeMerger().general_registry is not None

    except Exception as exc:
        logger.warning(f"Failed to check general ontology status: {exc}")
        return False
def load_ontology_types_with_fallback(
    scene_id: Optional[UUID],
    workspace_id: UUID,
    db: Session,
    enable_general_fallback: bool = True
) -> Optional["OntologyTypeList"]:
    """Load scene ontology types, falling back to general types.

    Convenience wrapper: tries ``load_ontology_types_for_scene`` first and,
    when that yields None and the fallback is enabled, asks
    ``create_empty_ontology_type_list`` for an empty list instead.

    Args:
        scene_id: Scene identifier.
        workspace_id: Workspace identifier.
        db: Database session.
        enable_general_fallback: Whether to fall back when the scene has no
            types.

    Returns:
        An OntologyTypeList, or None.
    """
    scene_types = load_ontology_types_for_scene(
        scene_id=scene_id,
        workspace_id=workspace_id,
        db=db
    )

    # Scene types found, or fallback disabled: hand back the result as is.
    if scene_types is not None or not enable_general_fallback:
        return scene_types

    # No scene types: an empty list signals "use general types only".
    fallback = create_empty_ontology_type_list()
    if fallback:
        logger.info("No scene ontology types, will use general ontology types only")

    return fallback

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""本体类型合并服务模块
本模块实现本体类型合并服务,负责按优先级合并场景类型与通用类型。
合并优先级:
1. 场景特定类型(最高优先级)
2. 核心通用类型
3. 相关父类类型(最低优先级)
Classes:
OntologyTypeMerger: 本体类型合并服务类
Constants:
DEFAULT_CORE_GENERAL_TYPES: 默认核心通用类型集合
"""
import logging
from typing import List, Optional, Set
from app.core.memory.models.ontology_general_models import GeneralOntologyTypeRegistry
from app.core.memory.models.ontology_extraction_models import OntologyTypeInfo, OntologyTypeList
logger = logging.getLogger(__name__)
# Default core general-purpose type names; OntologyTypeMerger copies this set
# when it is constructed without a custom ``core_types`` list.
DEFAULT_CORE_GENERAL_TYPES: Set[str] = {
    "Person", "Organization", "Company", "GovernmentAgency",
    "Place", "Location", "City", "Country", "Building",
    "Event", "SportsEvent", "MusicEvent", "SocialEvent",
    "Work", "Book", "Film", "Software", "Album",
    "Concept", "TopicalConcept", "AcademicSubject",
    "Device", "Food", "Drug", "ChemicalSubstance",
    "TimePeriod", "Year",
}
class OntologyTypeMerger:
    """Merge scene-specific ontology types with general-purpose types.

    Produces the type list used for triplet extraction. Merge priority:
        1. Scene-specific types (highest) - description tagged ``[场景类型]``
        2. Core general types - description tagged ``[通用类型]``
        3. Parent types of scene types (lowest) - description tagged ``[通用父类]``

    Attributes:
        general_registry: Registry of general ontology types.
        max_types_in_prompt: Budget for the number of types in the prompt.
        core_types: Set of core general type names.

    Example:
        >>> registry = GeneralOntologyTypeRegistry()
        >>> merger = OntologyTypeMerger(registry, max_types_in_prompt=50)
        >>> merged = merger.merge(scene_types)
        >>> print(len(merged.types))
    """

    # Source tags prefixed to ``class_description`` by merge(); they are also
    # what get_merge_statistics() counts on.
    _SCENE_TAG = "[场景类型]"
    _GENERAL_TAG = "[通用类型]"
    _PARENT_TAG = "[通用父类]"

    def __init__(
        self,
        general_registry: "GeneralOntologyTypeRegistry",
        max_types_in_prompt: int = 50,
        core_types: Optional[List[str]] = None
    ):
        """Initialise the merger.

        Args:
            general_registry: Registry of general ontology types.
            max_types_in_prompt: Maximum number of types in the prompt
                (default 50).
            core_types: Custom core type names. When None or empty, a copy of
                ``DEFAULT_CORE_GENERAL_TYPES`` is used.
        """
        self.general_registry = general_registry
        self.max_types_in_prompt = max_types_in_prompt
        self.core_types: Set[str] = set(core_types) if core_types else DEFAULT_CORE_GENERAL_TYPES.copy()

    def update_core_types(self, core_types: List[str]) -> None:
        """Replace the core type set; takes effect on the next merge().

        Args:
            core_types: New list of core type names.
        """
        self.core_types = set(core_types)
        logger.info(f"核心类型已更新: {len(self.core_types)} 个类型")

    def merge(
        self,
        scene_types: Optional["OntologyTypeList"],
        include_related_types: bool = True
    ) -> "OntologyTypeList":
        """Merge scene types with general types by priority.

        Scene types are always included in full. Core general types and then
        parent types are appended only while the running total stays below
        ``max_types_in_prompt``; if the scene types alone exceed the budget,
        no general or parent types are added.

        Args:
            scene_types: Scene-specific type list; may be None.
            include_related_types: Whether to append the registry parent of
                each scene type (default True).

        Returns:
            The merged type list; every description carries a source tag.
        """
        merged_types: List[OntologyTypeInfo] = []
        seen_names: Set[str] = set()

        # 1. Scene-specific types (highest priority), de-duplicated by name.
        scene_type_count = 0
        if scene_types and scene_types.types:
            for scene_type in scene_types.types:
                if scene_type.class_name not in seen_names:
                    merged_types.append(OntologyTypeInfo(
                        class_name=scene_type.class_name,
                        class_description=f"{self._SCENE_TAG} {scene_type.class_description}"
                    ))
                    seen_names.add(scene_type.class_name)
                    scene_type_count += 1

        # 2. Core general types. Iterate in sorted order: set iteration order
        # depends on string hashing, so without sorting the types that survive
        # the slot limit could differ from one process to the next.
        remaining_slots = self.max_types_in_prompt - len(merged_types)
        core_types_added: List[OntologyTypeInfo] = []
        for type_name in sorted(self.core_types):
            if remaining_slots <= 0:
                break
            if type_name in seen_names:
                continue
            general_type = self.general_registry.get_type(type_name)
            if general_type:
                description = (
                    general_type.labels.get("zh") or
                    general_type.description or
                    general_type.get_label("en") or
                    type_name
                )
                core_types_added.append(OntologyTypeInfo(
                    class_name=type_name,
                    class_description=f"{self._GENERAL_TAG} {description}"
                ))
                seen_names.add(type_name)
                remaining_slots -= 1
        merged_types.extend(core_types_added)

        # 3. Parent types of scene types that are known to the registry.
        related_types_added: List[OntologyTypeInfo] = []
        if include_related_types and scene_types and scene_types.types:
            for scene_type in scene_types.types:
                if remaining_slots <= 0:
                    break
                general_type = self.general_registry.get_type(scene_type.class_name)
                if general_type and general_type.parent_class:
                    parent_name = general_type.parent_class
                    if parent_name not in seen_names:
                        parent_type = self.general_registry.get_type(parent_name)
                        if parent_type:
                            description = (
                                parent_type.labels.get("zh") or
                                parent_type.description or
                                parent_name
                            )
                            related_types_added.append(OntologyTypeInfo(
                                class_name=parent_name,
                                class_description=f"{self._PARENT_TAG} {description}"
                            ))
                            seen_names.add(parent_name)
                            remaining_slots -= 1
        merged_types.extend(related_types_added)

        logger.info(
            f"类型合并完成: 场景类型 {scene_type_count} 个, "
            f"核心通用类型 {len(core_types_added)} 个, "
            f"相关类型 {len(related_types_added)} 个, "
            f"总计 {len(merged_types)}"
        )
        return OntologyTypeList(types=merged_types)

    def get_type_hierarchy_hint(self, type_name: str) -> Optional[str]:
        """Return the inheritance chain of a type, limited to 3 ancestors.

        The chain is formatted as "Type → Parent1 → Parent2 → Parent3".

        Args:
            type_name: Type name to look up in the registry.

        Returns:
            The hierarchy hint, or None when the type is unknown to the
            registry or has no ancestors.
        """
        general_type = self.general_registry.get_type(type_name)
        if not general_type:
            return None
        ancestors = self.general_registry.get_ancestors(type_name)
        if ancestors:
            # Join with an explicit arrow so the names do not run together.
            return " → ".join([type_name, *ancestors[:3]])
        return None

    def get_merge_statistics(self, scene_types: Optional["OntologyTypeList"]) -> dict:
        """Run merge() and count the merged types per source.

        Args:
            scene_types: Scene-specific type list.

        Returns:
            A dict with the keys:
                - total_types: number of merged types
                - scene_types: number tagged as scene types
                - general_types: number tagged as core general types
                - parent_types: number tagged as parent types
                - available_core_types: size of the configured core type set
                - registry_total_types: number of types in the registry
        """
        merged = self.merge(scene_types)

        def count_tagged(tag: str) -> int:
            # merge() always puts the tag first, so match on the prefix; a
            # substring test would miscount descriptions that mention a tag.
            return sum(1 for t in merged.types if t.class_description.startswith(tag))

        return {
            "total_types": len(merged.types),
            "scene_types": count_tagged(self._SCENE_TAG),
            "general_types": count_tagged(self._GENERAL_TAG),
            "parent_types": count_tagged(self._PARENT_TAG),
            "available_core_types": len(self.core_types),
            "registry_total_types": len(self.general_registry.types),
        }

View File

@@ -34,6 +34,8 @@ from app.core.memory.models.graph_models import (
StatementNode,
)
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
)
@@ -95,6 +97,8 @@ class ExtractionOrchestrator:
config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
embedding_id: Optional[str] = None,
ontology_types: Optional[OntologyTypeList] = None,
enable_general_types: bool = True,
language: str = "zh",
):
"""
@@ -119,6 +123,29 @@ class ExtractionOrchestrator:
self.progress_callback = progress_callback # 保存进度回调函数
self.embedding_id = embedding_id # 保存嵌入模型ID
self.language = language # 保存语言配置
# 处理本体类型配置
# 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并
# 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合
if enable_general_types and ontology_types:
from app.core.memory.ontology_services.ontology_type_loader import (
get_ontology_type_merger,
is_general_ontology_enabled,
)
if is_general_ontology_enabled():
merger = get_ontology_type_merger()
self.ontology_types = merger.merge(ontology_types)
logger.info(
f"已启用通用本体类型融合: 场景类型 {len(ontology_types.types) if ontology_types.types else 0} 个 -> "
f"合并后 {len(self.ontology_types.types) if self.ontology_types.types else 0}"
)
else:
self.ontology_types = ontology_types
logger.info("通用本体类型功能已在配置中禁用,仅使用场景类型")
else:
self.ontology_types = ontology_types
if not enable_general_types and ontology_types:
logger.info("enable_general_types=False仅使用场景类型")
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
@@ -130,7 +157,7 @@ class ExtractionOrchestrator:
llm_client=llm_client,
config=self.config.statement_extraction,
)
self.triplet_extractor = TripletExtractor(llm_client=llm_client, language=language)
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
logger.info("ExtractionOrchestrator 初始化完成")

View File

@@ -14,7 +14,7 @@ import time
from typing import List, Optional
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.ontology_models import (
from app.core.memory.models.ontology_scenario_models import (
OntologyClass,
OntologyExtractionResponse,
)
@@ -118,7 +118,7 @@ class OntologyExtractor:
logger.info(
f"Starting ontology extraction - scenario_length={len(scenario)}, "
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
f"timeout={timeout}"
f"timeout={timeout}, language={language}"
)
try:

View File

@@ -1,6 +1,6 @@
import os
import asyncio
from typing import List, Dict
from typing import List, Dict, Optional
from app.core.logging_config import get_memory_logger
from app.core.memory.llm_tools.openai_client import OpenAIClient
@@ -8,6 +8,7 @@ from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
from app.core.memory.models.triplet_models import TripletExtractionResponse
from app.core.memory.models.message_models import DialogData, Statement
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
from app.core.memory.utils.log.logging_utils import prompt_logger
logger = get_memory_logger(__name__)
@@ -17,14 +18,21 @@ logger = get_memory_logger(__name__)
class TripletExtractor:
"""Extracts knowledge triplets and entities from statements using LLM"""
def __init__(self, llm_client: OpenAIClient, language: str = "zh"):
def __init__(
self,
llm_client: OpenAIClient,
ontology_types: Optional[OntologyTypeList] = None,
language: str = "zh"):
"""Initialize the TripletExtractor with an LLM client
Args:
llm_client: OpenAIClient instance for processing
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
ontology_types: Optional OntologyTypeList containing predefined ontology types
for entity classification guidance
"""
self.llm_client = llm_client
self.ontology_types = ontology_types
self.language = language
def _get_language(self) -> str:
@@ -51,7 +59,8 @@ class TripletExtractor:
chunk_content=chunk_content,
json_schema=TripletExtractionResponse.model_json_schema(),
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language()
language=self._get_language(),
ontology_types=self.ontology_types,
)
# Create messages for LLM

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""本体解析工具模块
本模块提供本体文件解析功能,支持多种 RDF 格式的本体文件解析。
Modules:
ontology_parser: 本体文件解析器
"""
from .ontology_parser import MultiOntologyParser, OntologyParser
# Public names re-exported from .ontology_parser by this package.
__all__ = ["OntologyParser", "MultiOntologyParser"]

View File

@@ -0,0 +1,366 @@
# -*- coding: utf-8 -*-
"""本体文件解析器模块
本模块提供统一的本体文件解析功能,支持多种 RDF 格式:
- Turtle (.ttl)
- OWL/XML (.owl)
- RDF/XML (.rdf)
- N-Triples (.nt)
- JSON-LD (.jsonld)
解析器会自动根据文件扩展名推断格式,并在解析失败时尝试其他格式。
解析结果包含类定义的名称、URI、多语言标签、描述和父类信息。
Classes:
OntologyParser: 统一本体文件解析器
MultiOntologyParser: 多本体文件解析器
Example:
>>> parser = OntologyParser("ontology.ttl")
>>> registry = parser.parse()
>>> print(f"解析了 {len(registry.types)} 个类型")
>>> multi_parser = MultiOntologyParser(["ontology1.ttl", "ontology2.owl"])
>>> merged_registry = multi_parser.parse_all()
>>> print(f"合并后共 {len(merged_registry.types)} 个类型")
"""
import logging
import re
from typing import List, Optional
from rdflib import OWL, RDF, RDFS, Graph, URIRef
from app.core.memory.models.ontology_general_models import (
GeneralOntologyType,
GeneralOntologyTypeRegistry,
OntologyFileFormat,
)
logger = logging.getLogger(__name__)
class OntologyParser:
    """Parse a single ontology file into a type registry.

    Extracts class definitions from an RDF graph and builds a
    ``GeneralOntologyTypeRegistry``. The format is taken from the constructor
    argument or inferred from the file extension, with other formats tried as
    a fallback when parsing fails.

    Attributes:
        file_path: Path of the ontology file.
        file_format: File format; inferred from the extension when not given.
        graph: rdflib ``Graph`` holding the parsed RDF data.

    Example:
        >>> parser = OntologyParser("dbpedia.owl")
        >>> registry = parser.parse()
        >>> person_type = registry.get_type("Person")
        >>> if person_type:
        ...     print(f"Person URI: {person_type.class_uri}")
    """

    # Local names of built-in root classes that are never registered as types
    # nor reported as parents.
    _BUILTIN_CLASS_NAMES = ('Thing', 'Resource')

    def __init__(
        self,
        file_path: str,
        file_format: Optional["OntologyFileFormat"] = None,
    ):
        """Initialise the parser.

        Args:
            file_path: Path of the ontology file.
            file_format: File format; inferred from the extension via
                ``OntologyFileFormat.from_extension`` when not given.
        """
        self.file_path = file_path
        self.file_format = file_format or OntologyFileFormat.from_extension(file_path)
        self.graph = Graph()

    def parse(self) -> "GeneralOntologyTypeRegistry":
        """Parse the ontology file and return the type registry.

        Loads the file (with format fallback), then walks every ``owl:Class``
        and ``rdfs:Class`` subject, registering each accepted class and
        recording its parent in the registry hierarchy.

        Returns:
            Registry containing the parsed types and their hierarchy.

        Raises:
            ValueError: When the file cannot be parsed in any format.
        """
        logger.info(f"开始解析本体文件: {self.file_path}")

        self._parse_with_fallback()

        registry = GeneralOntologyTypeRegistry()
        registry.source_files.append(self.file_path)

        # URIs already registered; a set keeps the rdfs:Class pass linear
        # instead of rebuilding a list of every registered URI per subject.
        registered_uris = set()

        # Pass 1: owl:Class subjects.
        for class_uri in self.graph.subjects(RDF.type, OWL.Class):
            type_info = self._parse_class(class_uri)
            if type_info:
                registry.types[type_info.class_name] = type_info
                registered_uris.add(str(class_uri))
                self._update_hierarchy(registry, type_info)

        # Pass 2: rdfs:Class subjects not already handled as owl:Class.
        for class_uri in self.graph.subjects(RDF.type, RDFS.Class):
            uri_str = str(class_uri)
            if uri_str in registered_uris:
                continue
            type_info = self._parse_class(class_uri)
            if type_info and type_info.class_name not in registry.types:
                registry.types[type_info.class_name] = type_info
                registered_uris.add(uri_str)
                self._update_hierarchy(registry, type_info)

        logger.info(f"本体解析完成: {len(registry.types)} 个类型")
        return registry

    def _parse_with_fallback(self) -> None:
        """Parse the file, trying other formats when the first one fails.

        The configured format is tried first; on failure RDF_XML, TURTLE,
        N_TRIPLES and JSON_LD are tried in that order (skipping the format
        that already failed).

        Raises:
            ValueError: When no format can parse the file; the last parser
                error is attached as the cause.
        """
        try:
            self.graph.parse(self.file_path, format=self.file_format.value)
            return
        except Exception as e:
            logger.warning(f"使用 {self.file_format.value} 格式解析失败: {e}")
            last_error = e

        fallback_formats = [
            OntologyFileFormat.RDF_XML,
            OntologyFileFormat.TURTLE,
            OntologyFileFormat.N_TRIPLES,
            OntologyFileFormat.JSON_LD,
        ]

        for fmt in fallback_formats:
            if fmt == self.file_format:
                continue
            # Start from an empty graph so triples added by a failed attempt
            # before it raised cannot leak into the final result.
            self.graph = Graph()
            try:
                self.graph.parse(self.file_path, format=fmt.value)
                logger.info(f"使用回退格式 {fmt.value} 解析成功")
                return
            except Exception as e:
                last_error = e

        raise ValueError(f"无法解析本体文件: {self.file_path}") from last_error

    def _update_hierarchy(
        self,
        registry: "GeneralOntologyTypeRegistry",
        type_info: "GeneralOntologyType"
    ) -> None:
        """Record ``type_info`` as a child of its parent class, if it has one.

        Args:
            registry: Type registry whose ``hierarchy`` mapping is updated.
            type_info: Parsed type information.
        """
        if type_info.parent_class:
            if type_info.parent_class not in registry.hierarchy:
                registry.hierarchy[type_info.parent_class] = set()
            registry.hierarchy[type_info.parent_class].add(type_info.class_name)

    def _parse_class(self, class_uri: "URIRef") -> Optional["GeneralOntologyType"]:
        """Parse a single class definition from the graph.

        Collects the class name, URI, labels, description and parent class.
        Classes without an extractable local name, ``_:``-prefixed nodes and
        the built-in ``Thing`` / ``Resource`` classes are skipped.

        Args:
            class_uri: URI reference of the class.

        Returns:
            A ``GeneralOntologyType``, or None when the class is skipped.
        """
        uri_str = str(class_uri)
        class_name = self._extract_local_name(uri_str)

        # Skip nodes without a usable name, blank nodes and built-in types.
        if not class_name:
            return None
        if class_name.startswith('_:'):
            return None
        if class_name in self._BUILTIN_CLASS_NAMES:
            return None
        if uri_str.startswith('_:'):
            return None

        labels = self._extract_labels(class_uri)
        description = self._extract_description(class_uri)
        parent_class = self._extract_parent_class(class_uri)

        return GeneralOntologyType(
            class_name=class_name,
            class_uri=uri_str,
            labels=labels,
            description=description,
            parent_class=parent_class,
            source_file=self.file_path
        )

    def _extract_labels(self, class_uri: "URIRef") -> dict:
        """Collect the multilingual ``rdfs:label`` values of a class.

        Labels without a language tag are stored under "en". When the class
        has no label at all, its local name is used as the "en" label.

        Args:
            class_uri: URI reference of the class.

        Returns:
            Mapping from language code to label text.
        """
        labels = {}

        for label in self.graph.objects(class_uri, RDFS.label):
            lang = getattr(label, 'language', None) or "en"
            labels[lang] = str(label)

        # No label found: fall back to the class's local name.
        if not labels:
            class_name = self._extract_local_name(str(class_uri))
            if class_name:
                labels["en"] = class_name

        return labels

    def _extract_description(self, class_uri: "URIRef") -> Optional[str]:
        """Return the ``rdfs:comment`` of a class, preferring English.

        Args:
            class_uri: URI reference of the class.

        Returns:
            The English comment when one exists, otherwise the first comment
            encountered, or None when the class has no comment.
        """
        description = None

        for comment in self.graph.objects(class_uri, RDFS.comment):
            lang = getattr(comment, 'language', None)
            # An English comment wins immediately.
            if lang == "en":
                return str(comment)
            # Otherwise remember the first comment seen as a fallback.
            if description is None:
                description = str(comment)

        return description

    def _extract_parent_class(self, class_uri: "URIRef") -> Optional[str]:
        """Return the first usable ``rdfs:subClassOf`` parent of a class.

        ``_:``-prefixed nodes, parents without an extractable local name and
        the built-in ``Thing`` / ``Resource`` classes are skipped.

        Args:
            class_uri: URI reference of the class.

        Returns:
            The parent's local name, or None when there is no usable parent.
        """
        for parent_uri in self.graph.objects(class_uri, RDFS.subClassOf):
            parent_uri_str = str(parent_uri)
            if parent_uri_str.startswith('_:'):
                continue
            parent_name = self._extract_local_name(parent_uri_str)
            if parent_name and parent_name not in self._BUILTIN_CLASS_NAMES:
                return parent_name

        return None

    def _extract_local_name(self, uri: str) -> Optional[str]:
        """Extract the local name from a URI.

        Two URI shapes are supported:
            1. ``#``-separated, e.g. http://example.org/ontology#Person
            2. ``/``-separated, e.g. http://dbpedia.org/ontology/Person

        Args:
            uri: Full URI string.

        Returns:
            The text after the last ``#`` (when the URI contains one) or
            after the last ``/``; None for ``_:``-prefixed nodes, for URIs
            with an empty fragment or trailing slash, and for strings
            containing neither separator.
        """
        if uri.startswith('_:'):
            return None

        # A '#' always delimits the local name. An empty fragment means there
        # is no name; falling through to '/' would yield e.g. "ontology#".
        if '#' in uri:
            return uri.rsplit('#', 1)[1] or None

        if '/' in uri:
            return uri.rsplit('/', 1)[1] or None

        return None
class MultiOntologyParser:
    """Parse several ontology files into one merged type registry.

    Each file is parsed with ``OntologyParser`` and folded into a single
    registry through ``GeneralOntologyTypeRegistry.merge``.
    NOTE(review): precedence between same-named types is decided by that
    ``merge`` method (presumably first-loaded wins) - confirm there.

    Attributes:
        file_paths: Ontology file paths, processed in order.

    Example:
        >>> parser = MultiOntologyParser([
        ...     "General_purpose_entity.ttl",
        ...     "domain_specific.owl"
        ... ])
        >>> registry = parser.parse_all()
        >>> print(f"合并后共 {len(registry.types)} 个类型")
    """

    def __init__(self, file_paths: List[str]):
        """Store the list of ontology files to parse.

        Args:
            file_paths: Ontology file paths.
        """
        self.file_paths = file_paths

    def parse_all(self) -> GeneralOntologyTypeRegistry:
        """Parse every configured file and merge the results.

        A file that fails to parse or merge is logged as a warning and
        skipped; processing continues with the next file.

        Returns:
            The merged type registry.
        """
        combined = GeneralOntologyTypeRegistry()

        for path in self.file_paths:
            try:
                combined.merge(OntologyParser(path).parse())
                logger.info(f"已合并本体文件: {path}")
            except Exception as exc:
                # Best effort: one bad file must not abort the whole load.
                logger.warning(f"跳过无法解析的本体文件 {path}: {exc}")

        logger.info(f"多本体合并完成: 共 {len(combined.types)} 个类型")
        return combined

View File

@@ -9,22 +9,29 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def get_prompts(message: str) -> list[dict]:
async def get_prompts(message: str, language: str = "zh") -> list[dict]:
"""
Renders system and user prompts using Jinja2 templates.
Args:
message: The message content
language: Language for output ("zh" for Chinese, "en" for English)
Returns:
List of message dictionaries with role and content
"""
system_template = prompt_env.get_template("system.jinja2")
user_template = prompt_env.get_template("user.jinja2")
system_prompt = system_template.render()
user_prompt = user_template.render(message=message)
system_prompt = system_template.render(language=language)
user_prompt = user_template.render(message=message, language=language)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('system', system_prompt)
log_prompt_rendering('user', user_prompt)
# 可选:记录模板渲染信息(仅当 prompt_templates.log 存在时生效)
log_template_rendering('system.jinja2', {})
log_template_rendering('user.jinja2', {'message': message})
log_template_rendering('system.jinja2', {'language': language})
log_template_rendering('user.jinja2', {'message': message, 'language': language})
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
@@ -38,6 +45,7 @@ async def render_statement_extraction_prompt(
include_dialogue_context: bool = False,
dialogue_content: str | None = None,
max_dialogue_chars: int | None = None,
language: str = "zh",
) -> str:
"""
Renders the statement extraction prompt using the extract_statement.jinja2 template.
@@ -46,6 +54,11 @@ async def render_statement_extraction_prompt(
chunk_content: The content of the chunk to process
definitions: Label definitions for statement classification
json_schema: JSON schema for the expected output format
granularity: Extraction granularity level (1-3)
include_dialogue_context: Whether to include full dialogue context
dialogue_content: Full dialogue content for context
max_dialogue_chars: Maximum characters for dialogue context
language: Language for output ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -69,6 +82,7 @@ async def render_statement_extraction_prompt(
granularity=granularity,
include_dialogue_context=include_dialogue_context,
dialogue_context=ctx,
language=language,
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('statement extraction', rendered_prompt)
@@ -90,6 +104,7 @@ async def render_temporal_extraction_prompt(
temporal_guide: dict,
statement_guide: dict,
json_schema: dict,
language: str = "zh",
) -> str:
"""
Renders the temporal extraction prompt using the extract_temporal.jinja2 template.
@@ -100,6 +115,7 @@ async def render_temporal_extraction_prompt(
temporal_guide: Guidance on temporal types.
statement_guide: Guidance on statement types.
json_schema: JSON schema for the expected output format.
language: Language for output ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as a string.
@@ -111,6 +127,7 @@ async def render_temporal_extraction_prompt(
temporal_guide=temporal_guide,
statement_guide=statement_guide,
json_schema=json_schema,
language=language,
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('temporal extraction', rendered_prompt)
@@ -130,6 +147,7 @@ def render_entity_dedup_prompt(
context: dict,
json_schema: dict,
disambiguation_mode: bool = False,
language: str = "zh",
) -> str:
"""
Render the entity deduplication prompt using the entity_dedup.jinja2 template.
@@ -139,6 +157,8 @@ def render_entity_dedup_prompt(
entity_b: Dict of entity B attributes
context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements)
json_schema: JSON schema for the structured output (EntityDedupDecision)
disambiguation_mode: Whether to use disambiguation mode
language: Language for output ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -157,6 +177,7 @@ def render_entity_dedup_prompt(
relation_statements=context.get("relation_statements", []),
json_schema=json_schema,
disambiguation_mode=disambiguation_mode,
language=language,
)
# prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===")
@@ -177,7 +198,14 @@ def render_entity_dedup_prompt(
# Args:
# entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
async def render_triplet_extraction_prompt(
statement: str,
chunk_content: str,
json_schema: dict,
predicate_instructions: dict = None,
language: str = "zh",
ontology_types: "OntologyTypeList | None" = None,
) -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -187,17 +215,31 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
json_schema: JSON schema for the expected output format
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("extract_triplet.jinja2")
# 准备本体类型数据
ontology_type_section = ""
ontology_type_names = []
type_hierarchy_hints = []
if ontology_types and ontology_types.types:
ontology_type_section = ontology_types.to_prompt_section()
ontology_type_names = ontology_types.get_type_names()
type_hierarchy_hints = ontology_types.get_type_hierarchy_hints()
rendered_prompt = template.render(
statement=statement,
chunk_content=chunk_content,
json_schema=json_schema,
predicate_instructions=predicate_instructions,
language=language
language=language,
ontology_types=ontology_type_section,
ontology_type_names=ontology_type_names,
type_hierarchy_hints=type_hierarchy_hints,
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)
@@ -207,7 +249,10 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
'chunk_content': 'str',
'json_schema': 'TripletExtractionResponse.schema',
'predicate_instructions': 'PREDICATE_DEFINITIONS',
'language': language
'language': language,
'ontology_types': bool(ontology_type_section),
'ontology_type_count': len(ontology_type_names),
'type_hierarchy_hints_count': len(type_hierarchy_hints),
})
return rendered_prompt
@@ -249,7 +294,8 @@ async def render_memory_summary_prompt(
async def render_emotion_extraction_prompt(
statement: str,
extract_keywords: bool,
enable_subject: bool
enable_subject: bool,
language: str = "zh"
) -> str:
"""
Renders the emotion extraction prompt using the extract_emotion.jinja2 template.
@@ -258,6 +304,7 @@ async def render_emotion_extraction_prompt(
statement: The statement to analyze
extract_keywords: Whether to extract emotion keywords
enable_subject: Whether to enable subject classification
language: Language for output ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -266,7 +313,8 @@ async def render_emotion_extraction_prompt(
rendered_prompt = template.render(
statement=statement,
extract_keywords=extract_keywords,
enable_subject=enable_subject
enable_subject=enable_subject,
language=language
)
# 记录渲染结果到提示日志
@@ -467,7 +515,8 @@ async def render_ontology_extraction_prompt(
'scenario_len': len(scenario) if scenario else 0,
'domain': domain,
'max_classes': max_classes,
'json_schema': 'OntologyExtractionResponse.schema'
'json_schema': 'OntologyExtractionResponse.schema',
'language': language
})
return rendered_prompt

View File

@@ -1,9 +1,16 @@
===任务===
===Task===
{% if language == "zh" %}
你是一个实体去重/消歧判断助手。你将被提供两个实体的详细信息和上下文,请严格根据指引判断它们是否是同一真实世界实体,并在需要时进行类型消歧。
模式: {{ '消歧模式' if disambiguation_mode else '去重模式' }}
{% else %}
You are an entity deduplication/disambiguation assistant. You will be provided with detailed information and context for two entities. Please strictly follow the guidelines to determine whether they are the same real-world entity and perform type disambiguation when necessary.
===输入===
Mode: {{ 'Disambiguation Mode' if disambiguation_mode else 'Deduplication Mode' }}
{% endif %}
===Input===
{% if language == "zh" %}
实体A:
- 名称: "{{ entity_a.name | default('') }}"
- 类型: "{{ entity_a.entity_type | default('') }}"
@@ -34,8 +41,41 @@
{% for s in relation_statements %}
- {{ s }}
{% endfor %}
{% else %}
Entity A:
- Name: "{{ entity_a.name | default('') }}"
- Type: "{{ entity_a.entity_type | default('') }}"
- Description: "{{ entity_a.description | default('') }}"
- Aliases: {{ entity_a.aliases | default([]) }}
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
{# - Summary: "{{ entity_a.fact_summary | default('') }}" #}
- Connection Strength: "{{ entity_a.connect_strength | default('') }}"
===判定指引===
Entity B:
- Name: "{{ entity_b.name | default('') }}"
- Type: "{{ entity_b.entity_type | default('') }}"
- Description: "{{ entity_b.description | default('') }}"
- Aliases: {{ entity_b.aliases | default([]) }}
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
{# - Summary: "{{ entity_b.fact_summary | default('') }}" #}
- Connection Strength: "{{ entity_b.connect_strength | default('') }}"
Context:
- Same Group: {{ same_group | default(false) }}
- Type Consistent or Unknown: {{ type_ok | default(false) }}
- Type Similarity (0-1): {{ type_similarity | default(0.0) }}
- Name Text Similarity (0-1): {{ name_text_sim | default(0.0) }}
- Name Embedding Similarity (0-1): {{ name_embed_sim | default(0.0) }}
- Name Contains Relationship: {{ name_contains | default(false) }}
- Context Co-occurrence (same statement refers to both): {{ co_occurrence | default(false) }}
- Related Relationship Statements (from entity-entity edges):
{% for s in relation_statements %}
- {{ s }}
{% endfor %}
{% endif %}
===Guidelines===
{% if language == "zh" %}
{% if disambiguation_mode %}
- 这是"同名但类型不同"的消歧场景。请判断两者是否指向同一真实世界实体。
- 综合名称文本/向量相似度、别名、描述、摘要与上下文关系(同源与关系陈述)进行判断。
@@ -68,8 +108,43 @@
- 优先保留连接强度更强(strong/both)者;其余相同则保留描述/摘要更丰富者再相同时保留实体Acanonical_idx=0
- **注意**别名aliases已在三元组提取阶段获取合并时会自动整合无需在此阶段提取。
{% endif %}
{% else %}
{% if disambiguation_mode %}
- This is a disambiguation scenario for "same name but different types". Please determine whether they refer to the same real-world entity.
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships (co-occurrence and relationship statements).
- **Alias Handling (High Priority)**:
* If the alias lists of both entities have intersections, this is a strong signal of identity
* If one entity's name appears in another entity's aliases, it should be considered a high-confidence match
* If one entity's alias exactly matches another entity's name, it should be considered a high-confidence match
* Alias matching weight should be higher than pure name text similarity
- If unable to determine with sufficient confidence, handle conservatively: do not merge, and suggest blocking this pair in other fuzzy/heuristic merges (block_pair=true).
- If merging is needed (should_merge=true), select the "canonical entity" (canonical_idx) and **must** provide a suggested unified type (suggested_type).
- **Type Unification Principles (Important)**:
* Prioritize more specific and accurate types (e.g., HistoricalPeriod over Organization, MilitaryCapability over Concept)
* If both types are specific but different, choose the type that best matches the entity's core semantics
* Generic types (Concept, Phenomenon, Condition, State, Attribute, Event) have lower priority than domain-specific types
* Suggested type must be consistent with context and entity description
- Canonical entity priority: higher connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
{% else %}
- If entity types are the same or either is UNKNOWN/empty, can proceed as candidates; if types clearly conflict (e.g., person vs. item), unless aliases and descriptions are highly consistent, determine as different entities.
- **Alias Matching Priority (Highest Priority)**:
* If Entity A's name exactly matches any of Entity B's aliases, it should be considered a high-confidence match
* If Entity B's name exactly matches any of Entity A's aliases, it should be considered a high-confidence match
* If any alias of Entity A exactly matches any alias of Entity B, it should be considered a high-confidence match
* When aliases match exactly, merging should be considered even if name text similarity is low
* Alias matching confidence should be higher than pure name similarity matching
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships.
- When context co-occurs or there are clear relationship statements supporting identity (e.g., the same object is repeatedly mentioned or aliases correspond), the judgment threshold can be moderately lowered.
- Conservative decision: when unable to determine with sufficient confidence, do not merge (same_entity=false).
- If merging is needed, select the "canonical entity to retain" (canonical_idx) as the more appropriate one:
- Prioritize retaining the one with stronger connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
{% endif %}
{% endif %}
**Output format**
{% if language == "zh" %}
{% if disambiguation_mode %}
返回JSON格式必须包含以下字段
{
@@ -103,6 +178,41 @@
- confidence: 决策的置信度范围0.0-1.0
- reason: 决策理由的简短说明
{% endif %}
{% else %}
{% if disambiguation_mode %}
Return JSON format with the following required fields:
{
"should_merge": boolean,
"canonical_idx": 0 or 1,
"confidence": float (0.0-1.0),
"block_pair": boolean,
"suggested_type": "string or null",
"reason": "string"
}
**Field Descriptions**:
- should_merge: Whether these two entities should be merged (true/false)
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
- confidence: Confidence level of the decision, range 0.0-1.0
- block_pair: Whether to block this pair in other fuzzy/heuristic merges (true/false)
- suggested_type: Suggested unified type (string or null)
- reason: Brief explanation of the decision
{% else %}
Return JSON format with the following required fields:
{
"same_entity": boolean,
"canonical_idx": 0 or 1,
"confidence": float (0.0-1.0),
"reason": "string"
}
**Field Descriptions**:
- same_entity: Whether the two entities refer to the same real-world entity (true/false)
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
- confidence: Confidence level of the decision, range 0.0-1.0
- reason: Brief explanation of the decision
{% endif %}
{% endif %}
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
@@ -110,5 +220,9 @@
3. Do not include line breaks within JSON string values
4. Test your JSON output mentally to ensure it can be parsed correctly
{% if language == "zh" %}
输出语言应始终与输入语言相同。
{% else %}
The output language should always be the same as the input language.
{% endif %}
{{ json_schema }}

View File

@@ -17,9 +17,18 @@
#}
{% set scene_instructions = {
'education': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
'online_service': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
'outbound': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。'
'education': {
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
},
'online_service': {
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
},
'outbound': {
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
}
} %}
{% set scene_key = pruning_scene %}
@@ -27,8 +36,9 @@
{% set scene_key = 'education' %}
{% endif %}
{% set instruction = scene_instructions[scene_key] %}
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
{% if language == "zh" %}
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
场景说明:{{ instruction }}
@@ -46,4 +56,24 @@
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
}
}
{% else %}
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
Scenario Description: {{ instruction }}
Full Dialogue:
"""
{{ dialog_text }}
"""
Output strict JSON only (fixed keys, order doesn't matter):
{
"is_related": <true or false>,
"times": [<string>...],
"ids": [<string>...],
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
}
{% endif %}

View File

@@ -1,3 +1,4 @@
{% if language == "zh" %}
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
陈述句:{{ statement }}
@@ -55,3 +56,62 @@
- 主体分类要准确优先识别用户本人self
请以 JSON 格式返回结果。
{% else %}
You are a professional emotion analysis expert. Please analyze the emotional information in the following statement.
Statement: {{ statement }}
Please extract the following information:
1. emotion_type (Emotion Type):
- joy: happiness, delight, pleasure, satisfaction, cheerfulness
- sadness: sorrow, grief, disappointment, depression, regret
- anger: rage, irritation, dissatisfaction, annoyance, frustration
- fear: anxiety, worry, concern, nervousness, apprehension
- surprise: astonishment, amazement, shock, wonder
- neutral: neutral, objective statement, no obvious emotion
2. emotion_intensity (Emotion Intensity):
- 0.0-0.3: weak emotion
- 0.3-0.7: moderate emotion
- 0.7-1.0: strong emotion
{% if extract_keywords %}
3. emotion_keywords (Emotion Keywords):
- Words directly expressing emotions in the original sentence
- Extract up to 3 keywords
- Return empty list if no obvious emotion words
{% else %}
3. emotion_keywords (Emotion Keywords):
- Return empty list
{% endif %}
{% if enable_subject %}
4. emotion_subject (Emotion Subject):
- self: user's own emotions (includes "I", "we", "us" and other first-person pronouns)
- other: others' emotions (includes names, "he/she" and other third-person pronouns)
- object: evaluation of things (for products, places, events, etc.)
Note:
- If multiple subjects are present, prioritize identifying the user (self)
- If the subject cannot be clearly determined, default to self
5. emotion_target (Emotion Target):
- If there is a clear emotion target, extract its name
- If there is no clear target, return null
{% else %}
4. emotion_subject (Emotion Subject):
- Default to self
5. emotion_target (Emotion Target):
- Return null
{% endif %}
Notes:
- If the statement is an objective factual statement with no obvious emotion, mark as neutral
- Emotion intensity should match the context, do not over-interpret
- Emotion keywords should be accurate, do not add words not in the original sentence
- Subject classification should be accurate, prioritize identifying the user (self)
Please return the result in JSON format.
{% endif %}

View File

@@ -24,6 +24,23 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
{% endif %}
{%- endif %}
===Output Language===
{% if language == "en" -%}
**IMPORTANT: All output content MUST be in English.**
- Class names (name field): English in PascalCase format
- Chinese name (name_chinese field): Provide Chinese translation
- Descriptions: MUST be in English
- Examples: MUST be in English
- Domain: MUST be in English
{%- else -%}
**IMPORTANT: Output content language requirements:**
- Class names (name field): English in PascalCase format
- Chinese name (name_chinese field): Chinese translation
- Descriptions: MUST be in Chinese (中文)
- Examples: MUST be in Chinese (中文)
- Domain: Can be in Chinese or English
{%- endif %}
===Extraction Rules===
{% if language == "zh" %}
@@ -99,16 +116,31 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
- Aim for a balanced set covering the main concepts in the scenario
- Quality over quantity: prefer well-defined classes over exhaustive lists
**5. Clear Descriptions:**
{% if language == "en" -%}
- Provide concise, informative descriptions in English (max 500 characters)
- Describe what the class represents, not specific instances
- Use clear, natural English language that explains the class's role in the domain
{%- else -%}
- Provide concise, informative descriptions in English (max 500 characters)
- Describe what the class represents, not specific instances
- Use clear, natural English language
{%- endif %}
**6. Concrete Examples:**
{% if language == "en" -%}
- Provide 2-5 concrete instance examples in English for each class
- Examples should be specific, realistic instances of the class
- Examples help clarify the class's scope and meaning
- Use natural English language for examples
- Example format: ["Example1", "Example2", "Example3"]
{%- else -%}
- Provide 2-5 concrete instance examples in English for each class
- Examples should be specific, realistic instances of the class
- Examples help clarify the class's scope and meaning
- Example format: ["Example1", "Example2", "Example3"]
{%- endif %}
**7. Class Hierarchy:**
- Identify parent-child relationships where applicable
@@ -234,6 +266,64 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
}
{% else %}
{% if language == "en" -%}
**Example 1 (Healthcare Domain):**
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
Output:
{
"classes": [
{
"name": "Patient",
"name_chinese": "患者",
"description": "A person who receives medical care or treatment at a healthcare facility",
"examples": ["Outpatient", "Inpatient", "Emergency patient", "Chronic disease patient"],
"parent_class": null,
"entity_type": "Person",
"domain": "Healthcare"
},
{
"name": "MedicalProcedure",
"name_chinese": "医疗程序",
"description": "A systematic operation or process performed for medical diagnosis or treatment",
"examples": ["Surgery", "Blood test", "X-ray examination", "Vaccination"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
},
{
"name": "Diagnosis",
"name_chinese": "诊断",
"description": "The identification of a disease or condition based on symptoms and examination results",
"examples": ["Diabetes diagnosis", "Cancer diagnosis", "Flu diagnosis"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Healthcare"
},
{
"name": "Doctor",
"name_chinese": "医生",
"description": "A licensed medical professional who diagnoses and treats patients",
"examples": ["General practitioner", "Surgeon", "Cardiologist"],
"parent_class": null,
"entity_type": "Role",
"domain": "Healthcare"
},
{
"name": "Treatment",
"name_chinese": "治疗",
"description": "Medical care or therapy provided to cure or manage a disease condition",
"examples": ["Medication therapy", "Physical therapy", "Chemotherapy", "Surgical treatment"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
}
],
"domain": "Healthcare",
"namespace": "http://example.org/healthcare#"
}
{%- else -%}
**Example 1 (Healthcare Domain):**
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
@@ -334,6 +424,7 @@ Output:
"domain": "Education"
}
{% endif %}
{% endif %}
===Output Format===

View File

@@ -5,8 +5,13 @@
===Tasks===
{% if language == "zh" %}
你的任务是根据详细的提取指南,从提供的对话片段中识别和提取陈述句。
每个陈述句必须按照下面提到的标准进行标记。
{% else %}
Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines.
Each statement must be labeled as per the criteria mentioned below.
{% endif %}
===Inputs===
{% if inputs %}
@@ -17,6 +22,32 @@ Each statement must be labeled as per the criteria mentioned below.
===Extraction Instructions===
{% if language == "zh" %}
{% if granularity %}
{% if granularity == 3 %}
原子化和清晰:构建陈述句以清楚地显示单一的主谓宾关系。最好有多个较小的陈述句,而不是一个复杂的陈述句。
上下文独立:陈述句必须在不需要阅读整个对话的情况下可以理解。
{% elif granularity == 2 %}
在句子级别提取陈述句。每个陈述句应对应一个单一、完整的思想(通常是来源中的一个完整句子),但要重新表述以获得最大的清晰度,删除对话填充词(例如,"嗯"、"像"、感叹词)。
{% elif granularity == 1 %}
仅提取精华句子,并将片段总结为多个独立的陈述句,每个陈述句关注事实陈述、用户偏好、关系和显著的时间上下文。
{% endif %}
{% endif %}
上下文解析要求:
- 将指示代词("那个"、"这个"、"那些"、"这些")解析为其具体指代对象
- 如果陈述句包含无法从对话上下文中解析的模糊引用,则:
a) 扩展陈述句以包含对话早期的缺失上下文
b) 标记陈述句为需要额外上下文
c) 如果陈述句在没有上下文的情况下变得无意义,则跳过提取
对话上下文和共指消解:
- 将每个陈述句归属于说出它的参与者。
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
- 将所有代词解析为对话上下文中的具体人物或实体。
- 识别并将抽象引用解析为其具体名称(如果提到)。
- 将缩写和首字母缩略词扩展为其完整形式。
{% else %}
{% if granularity %}
{% if granularity == 3 %}
Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one.
@@ -29,7 +60,7 @@ Extract only essence sentences and summarize the chunk into multiple, standalone
{% endif %}
Context Resolution Requirements:
- Resolve demonstrative pronouns ("that," "this," "those","这个", "那个") to their specific referents
- Resolve demonstrative pronouns ("that," "this," "those") to their specific referents
- If a statement contains vague references that cannot be resolved from the conversation context, either:
a) Expand the statement to include the missing context from earlier in the conversation
b) Mark the statement as requiring additional context
@@ -41,16 +72,36 @@ Conversational Context & Co-reference Resolution:
- Resolve all pronouns to the specific person or entity from the conversation's context.
- Identify and resolve abstract references to their specific names if mentioned.
- Expand abbreviations and acronyms to their full form.
{% endif %}
{% if include_dialogue_context %}
{% if language == "zh" %}
===完整对话上下文===
以下是完整的对话上下文,以帮助您理解引用、代词和对话流程:
{% else %}
===Full Dialogue Context===
The following is the complete dialogue context to help you understand references, pronouns, and conversational flow:
{% endif %}
{{ dialogue_context }}
{% if language == "zh" %}
===对话上下文结束===
{% else %}
===End of Dialogue Context===
{% endif %}
{% endif %}
{% if language == "zh" %}
过滤和格式化:
- 仅提取陈述句。
不要提取问题、命令、问候语或对话填充词。
时间精度:
包括任何明确的日期、时间或定量限定符。
如果一个句子既描述了事件的开始(静态)又描述了其持续性质(动态),则将两者提取为单独的陈述句。
{% else %}
Filtering and Formatting:
- Extract only declarative statements.
@@ -59,18 +110,114 @@ Temporal Precision:
Include any explicit dates, times, or quantitative qualifiers.
If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements.
{% endif %}
{%- if definitions %}
{%- for section_key, section_dict in definitions.items() %}
==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ====
==== {{ tidy(section_key) | upper }} {% if language == "zh" %}定义和指导{% else %}DEFINITIONS & GUIDANCE{% endif %} ====
{%- for category, details in section_dict.items() %}
{{ loop.index }}. {{ category }}
- Definition: {{ details.get("definition", "") }}
- {% if language == "zh" %}定义{% else %}Definition{% endif %}: {{ details.get("definition", "") }}
{% endfor -%}
{% endfor -%}
{% endif -%}
===Examples===
{% if language == "zh" %}
示例 1: 英文对话
示例片段: """
日期: 2024年3月15日
参与者:
- Sarah Chen (用户)
- 助手 (AI)
用户: "我最近一直在尝试水彩画,画了一些花朵。"
AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
用户: "我认为色彩组合可以改进,但我真的很喜欢玫瑰和百合。"
"""
示例输出: {
"statements": [
{
"statement": "Sarah Chen 最近一直在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "Sarah Chen 画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。",
"statement_type": "FACT",
"temporal_type": "ATEMPORAL",
"relevance": "IRRELEVANT"
},
{
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
}
]
}
示例 2: 中文对话示例
示例片段: """
日期: 2024年3月15日
参与者:
- 张曼婷 (用户)
- 小助手 (AI助手)
用户: "我最近在尝试水彩画,画了一些花朵。"
AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。"
"""
示例输出: {
"statements": [
{
"statement": "张曼婷最近在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "张曼婷画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。",
"statement_type": "FACT",
"temporal_type": "ATEMPORAL",
"relevance": "IRRELEVANT"
},
{
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "张曼婷很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
}
]
}
{% else %}
Example 1: English Conversation
Example Chunk: """
Date: March 15, 2024
@@ -164,8 +311,33 @@ Example Output: {
}
]
}
{% endif %}
===End of Examples===
{% if language == "zh" %}
===反思过程===
提取陈述句后,执行以下自我审查步骤:
**步骤 1: 归属检查**
- 确认每个陈述句都正确归属于正确的说话者
- 验证说话者名称在整个过程中使用一致
- 检查 AI 助手陈述句是否正确归属
**步骤 2: 完整性审查**
- 确保没有遗漏重要的陈述句
- 检查时间信息是否保留
**步骤 3: 分类验证**
- 审查 statement_type 分类（FACT/OPINION/PREDICTION/SUGGESTION）
- 验证 temporal_type 分配（STATIC/DYNAMIC/ATEMPORAL）
- 确保分类与提供的定义一致
**步骤 4: 最终质量检查**
- 删除任何问题、命令或对话填充词
- 验证 JSON 格式合规性
- 确认输出语言与输入语言匹配
{% else %}
===Reflection Process===
After extracting statements, perform the following self-review steps:
@@ -188,6 +360,7 @@ After extracting statements, perform the following self-review steps:
- Remove any questions, commands, or conversational filler
- Verify JSON format compliance
- Confirm output language matches input language
{% endif %}
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
@@ -198,10 +371,21 @@ After extracting statements, perform the following self-review steps:
5. Example of proper escaping: "statement": "John said: \"I really like this book.\""
**LANGUAGE REQUIREMENT:**
{% if language == "zh" %}
- 输出语言应始终与输入语言匹配
- 如果输入是中文,则用中文提取陈述句
- 如果输入是英文,则用英文提取陈述句
- 保留原始语言,不要翻译
{% else %}
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
- Preserve the original language and do not translate
{% endif %}
{% if language == "zh" %}
仅返回与以下架构匹配的 JSON 对象数组中提取的标记陈述句列表:
{% else %}
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
{{ json_schema }}
{% endif %}
{{ json_schema }}

View File

@@ -14,68 +14,113 @@
#}
# Task
{% if language == "zh" %}
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
{% else %}
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
{% endif %}
# Input Data
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{val}}
{% endfor %}
{% endif %}
# Temporal Fields
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
{% if language == "zh" %}
- **valid_at**: 关系/事件开始或成为真实的时间（ISO 8601 格式）
- **invalid_at**: 关系/事件结束或停止为真的时间（ISO 8601 格式，如果正在进行则为 null）
{% else %}
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
{% endif %}
# Extraction Rules
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
## Core Principles
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
{% if language == "zh" %}
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
2. **使用参考/发布日期作为"现在"** 解释相对时间时
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
4. **对于时间点事件**,仅设置 `valid_at`
{% else %}
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
2. **Use the reference/publication date as "now"** when interpreting relative times
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
4. **For point-in-time events**, set only `valid_at`
{% endif %}
## Date Format Requirements
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
{% if language == "zh" %}
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- 如果未指定时间,使用 `00:00:00`(午夜)
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
- 如果仅提及月份,使用月份的第一天或最后一天
- 始终包含时区（如果未指定，使用 `Z` 表示 UTC）
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
{% else %}
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- If no time specified, use `00:00:00` (midnight)
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
- If only month mentioned, use first or last day of month
- Always include timezone (use `Z` for UTC if unspecified)
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
{% endif %}
## Statement Type Rules
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
{{ inputs.get("statement_type") | upper }} Statement Guidance:
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
{%for key, guide in statement_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
**Special Cases:**
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
{% if language == "zh" %}
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
{% else %}
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
{% endif %}
## Temporal Type Rules
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance:
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
{% for key, guide in temporal_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
{% if inputs.get('quarter') and inputs.get('publication_date') %}
## Quarter Reference
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
{% if language == "zh" %}
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用（Q1、Q2 等）的日期。
{% else %}
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
{% endif %}
{% endif %}
# Output Requirements
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
## JSON Formatting (CRITICAL)
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
{% if language == "zh" %}
1. **仅使用标准 ASCII 双引号** (") - 永远不要使用中文引号（“”）或其他 Unicode 变体
2. 使用反斜杠转义内部引号: `\"`
3. JSON 字符串值中不要有换行符
4. 正确关闭并用逗号分隔所有字段
{% else %}
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes (“”) or other Unicode variants
2. Escape internal quotes with backslash: `\"`
3. No line breaks within JSON string values
4. Properly close and comma-separate all fields
{% endif %}
## Language
## {% if language == "zh" %}语言{% else %}Language{% endif %}
{% if language == "zh" %}
输出语言必须与输入语言匹配。
{% else %}
Output language must match input language.
{% endif %}
{{ json_schema }}

View File

@@ -15,6 +15,37 @@ Extract entities and knowledge triplets from the given statement.
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
{% if ontology_types %}
===Ontology Type Guidance===
**CRITICAL: Use predefined ontology types for entity classification with the following priority:**
**Type Priority (from highest to lowest):**
1. **[场景类型] Scene Types** - Domain-specific types, use these first if applicable
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
3. **[通用父类] Parent Types** - Provide type hierarchy context
**Type Matching Rules:**
- Entity type MUST exactly match one of the predefined type names
- Do NOT modify, translate, or use variations of type names
- Prefer scene types over general types when both could apply
- If uncertain between types, check the type description for guidance
**Predefined Ontology Types:**
{{ ontology_types }}
{% if type_hierarchy_hints %}
**Type Hierarchy Reference:**
The following shows type inheritance relationships (Child → Parent → Grandparent):
{% for hint in type_hierarchy_hints %}
- {{ hint }}
{% endfor %}
{% endif %}
**Available Type Names (use EXACTLY as shown):**
{{ ontology_type_names | join(', ') }}
{% endif %}
===Guidelines===
**Entity Extraction:**

View File

@@ -1,2 +1,7 @@
{% if language == "zh" %}
你是一个从对话消息中提取实体节点的 AI 助手。
你的主要任务是提取和分类说话者以及对话中提到的其他重要实体。
{% else %}
You are an AI assistant that extracts entity nodes from conversational messages.
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
{% endif %}

View File

@@ -1,5 +1,13 @@
{% if language == "zh" %}
给定一个对话上下文和一个当前消息。
你的任务是提取在当前消息中**明确或隐含**提到的用户名称和年龄。
代词引用(如 he/she/they 或 this/that/those应消歧为引用实体的名称。
{{ message }}
{% else %}
You are given a conversation context and a CURRENT MESSAGE.
Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities.
{{ message }}
{{ message }}
{% endif %}

View File

@@ -11,7 +11,7 @@ import logging
import re
from typing import List, Tuple
from app.core.memory.models.ontology_models import OntologyClass
from app.core.memory.models.ontology_scenario_models import OntologyClass
logger = logging.getLogger(__name__)

View File

@@ -20,7 +20,7 @@ from owlready2 import (
OwlReadyInconsistentOntologyError,
)
from app.core.memory.models.ontology_models import OntologyClass
from app.core.memory.models.ontology_scenario_models import OntologyClass
logger = logging.getLogger(__name__)
@@ -583,3 +583,156 @@ class OWLValidator:
is_compatible = len(warnings) == 0
return is_compatible, warnings
def parse_owl_content(
    self,
    owl_content: str,
    format: str = "rdfxml"
) -> List[dict]:
    """Parse ontology classes out of OWL content.

    RDF/XML and Turtle input is loaded with rdflib and every named
    ``owl:Class`` is extracted together with its ``rdfs:label``,
    ``rdfs:comment`` and first named ``rdfs:subClassOf`` parent.
    JSON input is delegated to ``_parse_json_owl``.

    Args:
        owl_content: The OWL document as a string.
        format: Input format, one of "rdfxml", "turtle" or "json".

    Returns:
        A list of dicts, one per class, with the keys:
        - name: class identifier taken from the class URI
        - name_chinese: a label containing CJK characters, or None
        - description: the first rdfs:comment, or None
        - parent_class: local name of the first named parent other than
          owl:Thing, or None

    Raises:
        ValueError: If the format is unsupported or parsing fails.

    Examples:
        >>> validator = OWLValidator()
        >>> classes = validator.parse_owl_content(owl_xml, format="rdfxml")
        >>> for cls in classes:
        ...     print(cls["name"], cls["description"])
    """
    valid_formats = ["rdfxml", "turtle", "json"]
    if format not in valid_formats:
        raise ValueError(
            f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
        )

    # JSON is a simplified, non-RDF representation and is handled separately.
    if format == "json":
        return self._parse_json_owl(owl_content)

    def local_name(uri_text: str) -> str:
        """Return the fragment after '#', else the last '/' path segment."""
        separator = '#' if '#' in uri_text else '/'
        return uri_text.split(separator)[-1]

    # RDF/XML and Turtle are parsed with rdflib. The import stays inside the
    # try so that a missing rdflib is reported as ValueError, as before.
    try:
        from rdflib import BNode, Graph, OWL, RDF, RDFS

        g = Graph()
        rdf_format = "xml" if format == "rdfxml" else "turtle"
        g.parse(data=owl_content, format=rdf_format)

        classes = []

        for cls_uri in g.subjects(RDF.type, OWL.Class):
            # Anonymous class expressions (unionOf, intersectionOf, ...) are
            # blank nodes; their generated ids are not usable class names.
            if isinstance(cls_uri, BNode):
                continue

            cls_str = str(cls_uri)

            # Skip W3C built-in vocabulary and "/.well-known/" IRIs
            # (presumably skolemized blank nodes).
            if cls_str.startswith("http://www.w3.org/") or "/.well-known/" in cls_str:
                continue

            name = local_name(cls_str)

            # A URI ending in '#' or '/' yields an empty name; Thing is the
            # OWL root and is never imported as a user class.
            if not name or name == "Thing":
                continue

            # A class may carry several labels; keep one that contains CJK
            # characters as the Chinese display name.
            name_chinese = None
            for label in g.objects(cls_uri, RDFS.label):
                label_text = str(label)
                if any('\u4e00' <= char <= '\u9fff' for char in label_text):
                    name_chinese = label_text

            # The first rdfs:comment, if any, becomes the description.
            comments = list(g.objects(cls_uri, RDFS.comment))
            description = str(comments[0]) if comments else None

            # Take the first *named* parent that is not owl:Thing.
            parent_class = None
            for parent_uri in g.objects(cls_uri, RDFS.subClassOf):
                # Restrictions and other anonymous superclasses are blank
                # nodes and cannot be referenced by name.
                if isinstance(parent_uri, BNode):
                    continue
                parent_str = str(parent_uri)
                if parent_str == str(OWL.Thing) or parent_str.endswith("#Thing"):
                    continue
                parent_class = local_name(parent_str)
                break

            classes.append({
                "name": name,
                "name_chinese": name_chinese,
                "description": description,
                "parent_class": parent_class
            })

        logger.info(f"Parsed {len(classes)} classes from OWL content (format: {format})")
        return classes

    except Exception as e:
        error_msg = f"Failed to parse OWL文档格式不正确 content: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise ValueError(error_msg) from e
def _parse_json_owl(self, json_content: str) -> List[dict]:
    """Parse the simplified JSON ontology representation.

    Two layouts are accepted: ``{"ontology": {"classes": [...]}}`` and
    ``{"classes": [...]}``. Each class entry must be a JSON object.

    Args:
        json_content: The JSON document as a string.

    Returns:
        A list of dicts with the keys name, name_chinese, description and
        parent_class. A missing name becomes "", other missing keys None.

    Raises:
        ValueError: If the text is not valid JSON, has no 'classes' field,
            or 'classes' is not a list of objects.
    """
    import json

    # Only the decode step can raise JSONDecodeError, so only it is guarded.
    try:
        data = json.loads(json_content)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON content: {str(e)}") from e

    # Locate the class list; the nested "ontology" layout takes precedence.
    raw_classes = None
    if isinstance(data, dict):
        ontology = data.get("ontology")
        if isinstance(ontology, dict) and "classes" in ontology:
            raw_classes = ontology["classes"]
        elif "classes" in data:
            raw_classes = data["classes"]

    if raw_classes is None:
        raise ValueError("Invalid JSON format: missing 'classes' field")
    if not isinstance(raw_classes, list):
        raise ValueError("Invalid JSON format: 'classes' must be a list")

    classes = []
    for cls in raw_classes:
        # Reject malformed entries with ValueError so callers see the one
        # documented exception type rather than AttributeError.
        if not isinstance(cls, dict):
            raise ValueError("Invalid JSON format: each class must be an object")
        classes.append({
            "name": cls.get("name", ""),
            "name_chinese": cls.get("name_chinese"),
            "description": cls.get("description"),
            "parent_class": cls.get("parent_class")
        })

    logger.info(f"Parsed {len(classes)} classes from JSON content")
    return classes

View File

@@ -18,6 +18,8 @@ from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.code import CodeNode
__all__ = [
"BaseNode",
@@ -35,5 +37,7 @@ __all__ = [
"JinjaRenderNode",
"ParameterExtractorNode",
"QuestionClassifierNode",
"ToolNode"
"ToolNode",
"CodeNode",
"VariableAggregatorNode"
]

View File

@@ -1,3 +1,4 @@
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.nodes.code.node import CodeNode
__all__ = ["CodeNode"]
__all__ = ["CodeNode", "CodeNodeConfig"]

View File

@@ -216,7 +216,7 @@ class LLMNode(BaseNode):
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
return AIMessage(content=content, response_metadata=response.response_metadata)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""

View File

@@ -193,7 +193,8 @@ class ParameterExtractorNode(BaseNode):
model_resp = await llm.ainvoke(messages)
self.response_metadata = model_resp.response_metadata
result = json_repair.repair_json(model_resp.content, return_objects=True)
model_message = self.process_model_output(model_resp.content)
result = json_repair.repair_json(model_message, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")
return result

View File

@@ -131,7 +131,7 @@ class QuestionClassifierNode(BaseNode):
]
response = await llm.ainvoke(messages)
result = response.content.strip()
result = self.process_model_output(response.content)
self.response_metadata = response.response_metadata
if result in category_names:

View File

@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
api/scripts/query_ontology_matched_entities.py
根据 end_user_id 查询 Neo4j 中的 ExtractedEntity 节点,
并筛选出 entity_type 与以下类型匹配的实体:
1. 场景本体类型ontology_class 表)
2. 通用本体类型General_purpose_entity.ttl 等文件)
用法: python scripts/query_ontology_matched_entities.py <end_user_id> [config_id]
示例: python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1
python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1 fd547bb9-7b9e-47ea-ae53-242d208a31a2
"""
import sys
import os
import asyncio
from uuid import UUID
from typing import List, Dict, Any, Set, Optional
from collections import defaultdict
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from dotenv import load_dotenv
load_dotenv()
from app.db import SessionLocal
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.ontology_class_repository import OntologyClassRepository
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.ontology_services.ontology_type_loader import (
get_general_ontology_registry,
is_general_ontology_enabled,
)
async def get_entities_by_end_user_id(connector: Neo4jConnector, end_user_id: str) -> List[Dict[str, Any]]:
    """Fetch every ExtractedEntity node in Neo4j owned by one end user.

    Args:
        connector: Connector whose ``execute_query`` runs the Cypher text.
        end_user_id: Value matched against each node's ``end_user_id``.

    Returns:
        Whatever ``execute_query`` yields for the query: one record per
        entity with id, name, entity_type, description, end_user_id and
        created_at, ordered by created_at descending.
    """
    cypher = """
    MATCH (n:ExtractedEntity)
    WHERE n.end_user_id = $end_user_id
    RETURN
        n.id AS id,
        n.name AS name,
        n.entity_type AS entity_type,
        n.description AS description,
        n.end_user_id AS end_user_id,
        n.created_at AS created_at
    ORDER BY n.created_at DESC
    """
    return await connector.execute_query(cypher, end_user_id=end_user_id)
def get_ontology_types_from_scene(db, scene_id: UUID) -> Set[str]:
    """Collect the class names of all ontology classes belonging to a scene.

    Args:
        db: Database session handed to ``OntologyClassRepository``.
        scene_id: Identifier of the scene whose classes are looked up.

    Returns:
        The distinct ``class_name`` values returned by the repository.
    """
    scene_classes = OntologyClassRepository(db).get_by_scene(scene_id)
    names: Set[str] = set()
    for scene_class in scene_classes:
        names.add(scene_class.class_name)
    return names
def get_ontology_types_from_config(db, config_id: UUID) -> Optional[Set[str]]:
    """Resolve the ontology type names linked to a memory configuration.

    Args:
        db: Database session used for the lookups.
        config_id: Identifier of the memory configuration.

    Returns:
        The scene's ontology class names, or ``None`` when the configuration
        is missing or has no ``scene_id``.
    """
    config_row = MemoryConfigRepository.get_by_id(db, config_id)
    linked_scene_id = config_row.scene_id if config_row else None
    if linked_scene_id:
        return get_ontology_types_from_scene(db, linked_scene_id)
    return None
def get_all_ontology_types(db) -> Dict[str, Set[str]]:
    """Map every ontology class name to the names of the scenes that define it.

    All ``OntologyScene`` rows are read, and for each scene its ontology
    classes are fetched through ``OntologyClassRepository``.

    Args:
        db: Database session used for the scene query and the repository.

    Returns:
        A plain dict ``{class_name: {scene_name, ...}}``; empty when there are
        no scenes or no classes.
    """
    from app.models.ontology_scene import OntologyScene

    # Built once: the repository is constructed from the session alone.
    # NOTE(review): presumably it carries no per-scene state - confirm.
    class_repo = OntologyClassRepository(db)
    scenes_by_type: Dict[str, Set[str]] = {}
    for scene in db.query(OntologyScene).all():
        for ontology_class in class_repo.get_by_scene(scene.scene_id):
            scenes_by_type.setdefault(ontology_class.class_name, set()).add(scene.scene_name)
    return scenes_by_type
def get_general_ontology_types() -> Set[str]:
    """Return the names of the general-purpose ontology types.

    Returns:
        The keys of the general ontology registry's ``types`` mapping, or an
        empty set when the feature is disabled or loading the registry fails
        (the failure is printed, not raised).
    """
    general_types: Set[str] = set()
    if is_general_ontology_enabled():
        try:
            general_types = set(get_general_ontology_registry().types.keys())
        except Exception as e:
            # Best effort: the script carries on with scene types only.
            print(f"⚠️ 加载通用本体类型失败: {e}")
    return general_types
def _print_banner(title: str) -> None:
    """Print a section title framed by two 70-character rules."""
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")


def _percent(count: int, total: int) -> float:
    """Return ``count`` as a percentage of ``total`` (0 when ``total`` is not positive)."""
    return count / total * 100 if total > 0 else 0.0


def _print_type_preview(title: str, type_names: Set[str], limit: int = 20) -> None:
    """Print a numbered, sorted preview of at most ``limit`` type names."""
    print(title)
    print(f" {'-'*50}")
    for i, type_name in enumerate(sorted(type_names)[:limit], 1):
        print(f" {i:2}. {type_name}")
    if len(type_names) > limit:
        print(f" ... 还有 {len(type_names) - limit}")


def _print_entity_groups(
    groups: Dict[Any, List[Dict[str, Any]]],
    show_description: bool = True,
    sample_size: int = 3,
) -> None:
    """Print each type with its entity count and up to ``sample_size`` samples."""
    # key=str keeps the ordering of string keys and tolerates stray non-string keys.
    for type_name in sorted(groups, key=str):
        entities_of_type = groups[type_name]
        print(f"\n📌 类型: {type_name} ({len(entities_of_type)} 个)")
        print(f" {'-'*50}")
        for entity in entities_of_type[:sample_size]:
            name = entity.get('name', 'N/A')
            print(f"{name}")
            if show_description:
                desc = entity.get('description', '')
                desc_preview = (desc[:50] + "...") if desc and len(desc) > 50 else (desc or "无描述")
                print(f" 描述: {desc_preview}")
        if len(entities_of_type) > sample_size:
            print(f" ... 还有 {len(entities_of_type) - sample_size}")


def _print_top_types(title: str, distribution: Dict[Any, List[Dict[str, Any]]], limit: int = 10) -> None:
    """Print the ``limit`` types holding the most entities, largest first."""
    print(title)
    print(f" {'-'*50}")
    ranked = sorted(distribution.items(), key=lambda item: len(item[1]), reverse=True)
    for type_name, entities_list in ranked[:limit]:
        print(f" - {type_name}: {len(entities_list)}")


async def query_ontology_matched_entities(end_user_id: str, config_id: Optional[str] = None):
    """Report which of an end user's Neo4j entities match known ontology types.

    The entity_type of each ``ExtractedEntity`` is compared with two sets:
    the scene ontology types (from the configuration's scene, or from all
    scenes when no usable ``config_id`` is given) and the general ontology
    types. Matched and unmatched entities plus summary statistics are
    printed to stdout.

    Args:
        end_user_id: End user whose entities are queried.
        config_id: Optional memory configuration id as a UUID string. An
            invalid value is reported and the all-scenes fallback is used.

    Returns:
        None. Errors raised while querying are printed with a traceback
        instead of propagating; the DB session and the connector are closed
        in every case.
    """
    _print_banner("查询 Neo4j 中与本体类型匹配的实体")
    print(f"end_user_id: {end_user_id}")
    db = SessionLocal()
    connector = Neo4jConnector()
    try:
        # 1. Scene ontology types: prefer the scene linked to the configuration.
        scene_ontology_types: Set[str] = set()
        scene_name = "所有场景"
        if config_id:
            try:
                config_uuid = UUID(config_id)
                types = get_ontology_types_from_config(db, config_uuid)
                if types:
                    scene_ontology_types = types
                    memory_config = MemoryConfigRepository.get_by_id(db, config_uuid)
                    if memory_config and memory_config.scene_id:
                        scene_repo = OntologySceneRepository(db)
                        scene = scene_repo.get_by_id(memory_config.scene_id)
                        if scene:
                            scene_name = scene.scene_name
                    print(f"config_id: {config_id}")
                    print(f"关联场景: {scene_name}")
            except ValueError:
                print(f"⚠️ 无效的 config_id 格式: {config_id}")
        # No config_id, or nothing resolved from it: match against every scene's types.
        if not scene_ontology_types:
            all_types = get_all_ontology_types(db)
            scene_ontology_types = set(all_types.keys())
            print("使用所有场景本体类型进行匹配")

        # 2. General ontology types.
        general_ontology_types = get_general_ontology_types()
        _print_type_preview(
            f"\n📋 场景本体类型 (共 {len(scene_ontology_types)} 个):",
            scene_ontology_types,
        )
        _print_type_preview(
            f"\n📋 通用本体类型 (共 {len(general_ontology_types)} 个):",
            general_ontology_types,
        )

        # 3. Query the entities from Neo4j.
        print("\n🔍 正在查询 Neo4j...")
        entities = await get_entities_by_end_user_id(connector, end_user_id)
        if not entities:
            print(f"\n⚠️ 未找到 end_user_id={end_user_id} 的任何实体")
            return
        print(f" 找到 {len(entities)} 个实体")

        # 4. Classify every entity: scene only, general only, both, or unmatched.
        scene_matched_entities = []
        general_matched_entities = []
        both_matched_entities = []
        unmatched_entities = []
        scene_type_distribution = defaultdict(list)
        general_type_distribution = defaultdict(list)
        unmatched_by_type = defaultdict(list)
        for entity in entities:
            entity_type = entity.get('entity_type', '')
            in_scene = entity_type in scene_ontology_types
            in_general = entity_type in general_ontology_types
            if in_scene:
                scene_type_distribution[entity_type].append(entity)
            if in_general:
                general_type_distribution[entity_type].append(entity)
            if in_scene and in_general:
                both_matched_entities.append(entity)
            elif in_scene:
                scene_matched_entities.append(entity)
            elif in_general:
                general_matched_entities.append(entity)
            else:
                unmatched_entities.append(entity)
                # The query always returns the entity_type key, so a missing
                # property arrives as None; label it so the group keys stay
                # sortable strings.
                unmatched_by_type[entity_type or 'Unknown'].append(entity)

        # 5. Entities matching scene types.
        total_scene_matched = len(scene_matched_entities) + len(both_matched_entities)
        _print_banner(f"✅ 匹配场景本体类型的实体 (共 {total_scene_matched} 个)")
        if scene_type_distribution:
            _print_entity_groups(scene_type_distribution)
        else:
            print("\n (无匹配场景类型的实体)")

        # 6. Entities matching general types.
        total_general_matched = len(general_matched_entities) + len(both_matched_entities)
        _print_banner(f"✅ 匹配通用本体类型的实体 (共 {total_general_matched} 个)")
        if general_type_distribution:
            _print_entity_groups(general_type_distribution)
        else:
            print("\n (无匹配通用类型的实体)")

        # 7. Entities matching neither set (names only, no description preview).
        _print_banner(f"❌ 未匹配任何本体类型的实体 (共 {len(unmatched_entities)} 个)")
        if unmatched_entities:
            _print_entity_groups(unmatched_by_type, show_description=False)
        else:
            print("\n (所有实体都匹配本体类型)")

        # 8. Summary statistics.
        total_entities = len(entities)
        any_matched = total_entities - len(unmatched_entities)
        _print_banner("📊 统计摘要")
        print("\n 基础统计:")
        print(f" {'-'*50}")
        print(f" 总实体数: {total_entities}")
        print(f" 场景本体类型数: {len(scene_ontology_types)}")
        print(f" 通用本体类型数: {len(general_ontology_types)}")
        print("\n 匹配率统计:")
        print(f" {'-'*50}")
        for label, count in (
            ("匹配场景类型", total_scene_matched),
            ("匹配通用类型", total_general_matched),
            ("同时匹配两者", len(both_matched_entities)),
            ("仅匹配场景类型", len(scene_matched_entities)),
            ("仅匹配通用类型", len(general_matched_entities)),
            ("匹配任一类型", any_matched),
            ("未匹配任何类型", len(unmatched_entities)),
        ):
            print(f" {label}: {count} 个 ({_percent(count, total_entities):.1f}%)")

        # 9. Per-type distribution.
        if scene_type_distribution:
            _print_top_types("\n 场景类型分布 (Top 10):", scene_type_distribution)
        if general_type_distribution:
            _print_top_types("\n 通用类型分布 (Top 10):", general_type_distribution)
    except Exception as e:
        print(f"\n❌ 查询出错: {str(e)}")
        import traceback
        traceback.print_exc()
    finally:
        # Close the connector even if closing the session raises.
        try:
            db.close()
        finally:
            await connector.close()
if __name__ == "__main__":
    # Command-line entry point: <end_user_id> is required, [config_id] is optional.
    cli_args = sys.argv[1:]
    if not cli_args:
        for usage_line in (
            "用法: python scripts/query_ontology_matched_entities.py <end_user_id> [config_id]",
            "示例: python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1",
            " python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1 fd547bb9-7b9e-47ea-ae53-242d208a31a2",
        ):
            print(usage_line)
        sys.exit(1)
    end_user_id = cli_args[0]
    config_id = cli_args[1] if len(cli_args) > 1 else None
    asyncio.run(query_ontology_matched_entities(end_user_id, config_id))

View File

@@ -415,6 +415,9 @@ class MemoryConfig:
pruning_scene: Optional[str] = "education"
pruning_threshold: float = 0.5
# Ontology scene association
scene_id: Optional[UUID] = None
def __post_init__(self):
"""Validate configuration after initialization."""
if not self.config_name or not self.config_name.strip():

View File

@@ -330,6 +330,7 @@ class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID唯一支持UUID、整数或字符串")
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
model_config = ConfigDict(populate_by_name=True, extra="forbid")

View File

@@ -24,7 +24,7 @@ from uuid import UUID
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from app.core.memory.models.ontology_models import OntologyClass
from app.core.memory.models.ontology_scenario_models import OntologyClass
class ExtractionRequest(BaseModel):
@@ -74,47 +74,51 @@ class ExtractionResponse(BaseModel):
extracted_count: int = Field(..., description="提取的类数量")
class ExportRequest(BaseModel):
"""OWL文件导出请求模型
class ExportBySceneRequest(BaseModel):
"""按场景导出OWL文件请求模型
用于POST /api/ontology/export端点的请求体。
根据scene_id从数据库查询该场景下的所有本体类型并导出为OWL文件。
Attributes:
classes: 要导出的本体类列表
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
scene_id: 本体场景ID必填用于查询该场景下的所有类型
format: 导出格式可选值rdfxml(默认)、turtle
Examples:
>>> request = ExportRequest(
... classes=[...],
... format="rdfxml",
... include_metadata=True
>>> request = ExportBySceneRequest(
... scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
... format="rdfxml"
... )
"""
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
scene_id: UUID = Field(..., description="本体场景ID")
format: str = Field("rdfxml", description="导出格式,可选值:rdfxml(默认)、turtle")
class ExportResponse(BaseModel):
"""OWL文件导出响应模型
class ExportBySceneResponse(BaseModel):
"""按场景导出OWL文件响应模型
用于POST /api/ontology/export端点的响应体。
Attributes:
owl_content: OWL文件内容
format: 导出格式
filename: 导出文件名(含扩展名)
scene_id: 场景ID
scene_name: 场景名称
classes_count: 导出的类数量
Examples:
>>> response = ExportResponse(
>>> response = ExportBySceneResponse(
... owl_content="<?xml version='1.0'?>...",
... format="rdfxml",
... filename="medical_ontology.owl",
... scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
... scene_name="医疗场景",
... classes_count=7
... )
"""
owl_content: str = Field(..., description="OWL文件内容")
format: str = Field(..., description="导出格式")
filename: str = Field(..., description="导出文件名(含扩展名)")
scene_id: UUID = Field(..., description="场景ID")
scene_name: str = Field(..., description="场景名称")
classes_count: int = Field(..., description="导出的类数量")
@@ -459,3 +463,56 @@ class ClassListResponse(BaseModel):
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
items: List[ClassResponse] = Field(..., description="类型列表")
# ==================== OWL 导入相关 Schema ====================
class ImportOwlRequest(BaseModel):
    """Request body for importing an OWL document into a scene.

    Used by the POST .../ontology/import endpoint: the OWL content is parsed
    and its types are imported directly into the given scene.

    Attributes:
        scene_id: Target scene ID (required).
        owl_content: OWL file content as a string (at least one character).
        format: File format; documented values are ``rdfxml`` (default) and
            ``turtle``. The model itself accepts any string.

    Examples:
        >>> request = ImportOwlRequest(
        ...     scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
        ...     owl_content="<?xml version='1.0'?>...",
        ...     format="rdfxml"
        ... )
    """
    scene_id: UUID = Field(..., description="目标场景ID")
    owl_content: str = Field(..., min_length=1, description="OWL 文件内容")
    format: str = Field("rdfxml", description="文件格式可选值rdfxml默认、turtle")
class ImportOwlResponse(BaseModel):
    """Response describing the outcome of an OWL import.

    Attributes:
        scene_id: Scene ID.
        scene_name: Scene name.
        imported_count: Number of types imported successfully.
        skipped_count: Number of types skipped as duplicates (defaults to 0).
        items: The imported types.

    Examples:
        >>> response = ImportOwlResponse(
        ...     scene_id=UUID("..."),
        ...     scene_name="Smart manufacturing scene",
        ...     imported_count=4,
        ...     skipped_count=0,
        ...     items=[...]
        ... )
    """
    scene_id: UUID = Field(..., description="场景ID")
    scene_name: str = Field(..., description="场景名称")
    imported_count: int = Field(..., description="成功导入的类型数量")
    skipped_count: int = Field(0, description="跳过的数量(重复)")
    items: List[ClassResponse] = Field(..., description="导入的类型列表")

View File

@@ -303,6 +303,8 @@ class MemoryConfigService:
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
# Ontology scene association
scene_id=memory_config.scene_id,
)
elapsed_ms = (time.time() - start_time) * 1000
@@ -476,6 +478,43 @@ class MemoryConfigService:
"pruning_threshold": memory_config.pruning_threshold,
}
def get_ontology_types(self, memory_config: MemoryConfig):
    """Load the ontology types attached to a memory configuration's scene.

    Args:
        memory_config: MemoryConfig object whose ``scene_id`` selects the scene.

    Returns:
        An OntologyTypeList built from the scene's ontology classes, or None
        when no scene_id is set, the scene has no classes, or the lookup
        raises (the failure is logged as a warning).
    """
    from app.core.memory.models.ontology_extraction_models import OntologyTypeList
    from app.repositories.ontology_class_repository import OntologyClassRepository

    scene_id = memory_config.scene_id
    if not scene_id:
        logger.debug("No scene_id configured, skipping ontology type fetch")
        return None

    try:
        scene_classes = OntologyClassRepository(self.db).get_by_scene(scene_id)
        if scene_classes:
            type_list = OntologyTypeList.from_db_models(scene_classes)
            logger.info(
                f"Loaded {len(type_list.types)} ontology types for scene_id: {scene_id}"
            )
            return type_list
        logger.info(f"No ontology classes found for scene_id: {scene_id}")
        return None
    except Exception as e:
        # Best effort: callers get None rather than an exception.
        logger.warning(
            f"Failed to fetch ontology types for scene_id {scene_id}: {e}",
            exc_info=True
        )
        return None
def get_workspace_default_config(
self,
workspace_id: UUID

View File

@@ -280,12 +280,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not cid:
raise ValueError("未提供 payload.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# Load configuration from database only using centralized manager
try:
config_service = MemoryConfigService(self.db)
@@ -297,6 +291,30 @@ class DataConfigService: # 数据配置服务类PostgreSQL
except ConfigurationError as e:
raise RuntimeError(f"Configuration loading failed: {e}")
# 根据是否关联本体场景选择使用的文本
# 如果配置关联了本体场景scene_id 不为空),使用 custom_text如果提供
# 否则使用 dialogue_text
if memory_config.scene_id:
# 关联了本体场景,优先使用 custom_text
if hasattr(payload, 'custom_text') and payload.custom_text:
dialogue_text = payload.custom_text.strip()
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
else:
# 如果没有提供 custom_text回退到 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
else:
# 没有关联本体场景,使用 dialogue_text
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
# 验证最终使用的文本不为空
if not dialogue_text:
raise ValueError("试运行模式必须提供有效的文本内容dialogue_text 或 custom_text")
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件
progress_queue: asyncio.Queue = asyncio.Queue()

View File

@@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.ontology_models import (
from app.core.memory.models.ontology_scenario_models import (
OntologyClass,
OntologyExtractionResponse,
)
@@ -49,6 +49,10 @@ class OntologyService:
DEFAULT_LLM_TIMEOUT = 30.0
DEFAULT_ENABLE_OWL_VALIDATION = True
# 从环境变量获取默认语言
from app.core.config import settings
DEFAULT_LANGUAGE = settings.DEFAULT_LANGUAGE
def __init__(
self,
llm_client: OpenAIClient,

View File

@@ -142,6 +142,20 @@ async def run_pilot_extraction(
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
)
# 加载本体类型(如果配置了 scene_id支持通用类型回退
ontology_types = None
try:
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback
ontology_types = load_ontology_types_with_fallback(
scene_id=memory_config.scene_id,
workspace_id=memory_config.workspace_id,
db=db,
enable_general_fallback=True
)
except Exception as e:
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
@@ -150,6 +164,7 @@ async def run_pilot_extraction(
progress_callback=progress_callback,
embedding_id=str(memory_config.embedding_model_id),
language=language,
ontology_types=ontology_types,
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)

View File

@@ -1697,114 +1697,6 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
def long_term_storage_window_task(
    self,
    end_user_id: str,
    langchain_messages: List[Dict[str, Any]],
    config_id: str,
    scope: int = 6
) -> Dict[str, Any]:
    """Celery task for window-based long-term memory storage.

    Accumulates messages in Redis buffer until window size (scope) is reached,
    then writes batched messages to Neo4j.

    Args:
        end_user_id: End user identifier
        langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
        config_id: Memory configuration ID
        scope: Window size (number of messages before triggering write)

    Returns:
        Dict containing task status and metadata. On success it is the inner
        result ({"status": "SUCCESS", "strategy": "window", "scope": ...})
        extended with end_user_id, config_id, elapsed_time and task_id. On any
        exception a dict with status "FAILURE" and the error text is returned
        instead of re-raising.
    """
    from app.core.logging_config import get_logger
    logger = get_logger(__name__)
    logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        # Async body of the task; driven to completion on an event loop below.
        from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
        from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
        from app.core.memory.agent.utils.redis_tool import write_store
        from app.services.memory_config_service import MemoryConfigService
        db = next(get_db())
        try:
            # Save to Redis buffer first
            write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
            # Get workspace_id from end_user for fallback
            from app.models.app_model import App
            from app.models.end_user_model import EndUser
            # Stays None when the end user or its app cannot be found.
            workspace_id = None
            end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
            if end_user:
                app = db.query(App).filter(App.id == end_user.app_id).first()
                if app:
                    workspace_id = app.workspace_id
            # Load memory config with workspace fallback
            config_service = MemoryConfigService(db)
            memory_config = config_service.load_memory_config(
                config_id=config_id,
                workspace_id=workspace_id,
                service_name="LongTermStorageTask"
            )
            # Execute window-based dialogue storage
            await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
            return {"status": "SUCCESS", "strategy": "window", "scope": scope}
        finally:
            # Close the session whether or not the storage step succeeded.
            db.close()

    # nest_asyncio is optional; when it is not installed the task proceeds without it.
    try:
        import nest_asyncio
        nest_asyncio.apply()
    except ImportError:
        pass

    # Reuse the current event loop when it is usable, otherwise create and register a new one.
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        result = loop.run_until_complete(_run())
        elapsed_time = time.time() - start_time
        logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
        return {
            **result,
            "end_user_id": end_user_id,
            "config_id": config_id,
            "elapsed_time": elapsed_time,
            "task_id": self.request.id
        }
    except Exception as e:
        # Failures are reported through the returned payload rather than re-raised.
        elapsed_time = time.time() - start_time
        logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
        return {
            "status": "FAILURE",
            "strategy": "window",
            "error": str(e),
            "end_user_id": end_user_id,
            "config_id": config_id,
            "elapsed_time": elapsed_time,
            "task_id": self.request.id
        }
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
# def long_term_storage_time_task(
# self,

View File

@@ -1,4 +1,32 @@
{
"v0.2.3": {
"introduction": {
"codeName": "归墟",
"releaseDate": "2026-2-6",
"upgradePosition": "🐻 稳定性与细节打磨版本,万流归墟,静水流深",
"coreUpgrades": [
"1. 智能与记忆 🧠<br>* 提示词工程模块:新增专用提示词工程能力<br>* 长短期记忆整合:增强短期与长期记忆生命周期管理<br>* 双语记忆支持:解决情景记忆、显性记忆的双语问题",
"2. 系统架构 ⚙️<br>* 反思任务调度器:新增 worker-periodic 容器<br>* 模型配置降级:记忆管理正确降级使用空间模型",
"3. 问题修复 🔧<br>* 工作流分享修复多轮对话产生多个conversation<br>* 流式输出修复chat结尾缺少end标记<br>* 实体详情:移除未知类型记忆<br>* 提示词模板路径修复jinja2路径解析错误<br>* 知识库字段strategy更名为retrieve_type<br>* 空间头像:优化频繁调用模型接口<br>* 记忆仪表盘修复end_users接口无返回",
"<br>",
"v0.2.4 将继续完善工作流代码执行功能,并推出本体工程+记忆配置入口。",
"记忆熊,记得更牢,用得更好。🐻✨"
]
},
"introduction_en": {
"codeName": "Settle",
"releaseDate": "2026-2-6",
"upgradePosition": "🐻 Stability and refinement release — still waters run deep",
"coreUpgrades": [
"1. Intelligence & Memory 🧠<br>* Prompt Engineering Module: New dedicated prompt engineering capabilities<br>* Long-term & Short-term Memory Integration: Enhanced memory lifecycle management<br>* Bilingual Memory Support: Resolved dual-language issues in episodic and explicit memory",
"2. System Architecture ⚙️<br>* Reflection Task Worker: Added worker-periodic container for scheduled tasks<br>* Model Configuration Fallback: Memory management properly falls back to workspace model",
"3. Bug Fixes 🔧<br>* Workflow Sharing: Fixed multiple conversations created during multi-turn dialogues<br>* Streaming Output: Resolved missing end marker in chat streaming<br>* Entity Details: Removed unknown type memories from All view<br>* Prompt Template Paths: Fixed jinja2 path resolution errors<br>* Knowledge Base Schema: Renamed strategy to retrieve_type<br>* Workspace Avatar: Optimized frequent model API calls<br>* Memory Dashboard: Fixed end_users endpoint empty responses",
"<br>",
"v0.2.4 will continue with workflow code execution enhancements and the ontology engineering + memory configuration portal.",
"MemoryBear — remember better, work smarter. 🐻✨"
]
}
},
"v0.2.2": {
"introduction": {
"codeName": "淬锋Temper",

View File

@@ -129,3 +129,9 @@ KB_image2text_id=
config_id=
reranker_id=
# 本体类型融合配置 (记得写入env_example)
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
ONTOLOGY_EXPERIMENT_MODE=true # 是否允许通过 API 动态切换本体配置

View File

@@ -141,6 +141,7 @@ dependencies = [
"flower>=2.0.1",
"aiofiles>=23.0.0",
"owlready2>=0.46",
"rdflib>=7.0.0",
"lxml>=4.9.0",
"httpx>=0.28.0",
]

View File

@@ -39,6 +39,7 @@ python-multipart>=0.0.20
pyyaml==6.0.3
redis==6.4.0
rsa==4.9.1
rdflib>=6.0.0
six==1.17.0
sniffio==1.3.1
sqlalchemy==2.0.44

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/5 15:36

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6 14:45

View File

@@ -0,0 +1,622 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool, VariableSelector
# ==================== VariableSelector 测试 ====================
def test_variable_selector_from_string():
    """A dotted string is parsed into namespace, key and the full path."""
    parsed = VariableSelector.from_string("sys.message")

    assert parsed.namespace == "sys"
    assert parsed.key == "message"
    assert parsed.path == ["sys", "message"]
def test_variable_selector_from_list():
    """A selector built from a list exposes its parts and renders as a dotted string."""
    path_parts = ["conv", "username"]
    built = VariableSelector(path_parts)

    assert built.namespace == "conv"
    assert built.key == "username"
    assert str(built) == "conv.username"
def test_variable_selector_empty_path():
    """An empty path is rejected with a ValueError."""
    with pytest.raises(ValueError) as raised:
        VariableSelector([])

    error_text = str(raised.value)
    assert "变量路径不能为空" in error_text
def test_variable_selector_single_element():
    """A one-element path yields a namespace and no key."""
    namespace_only = VariableSelector(["sys"])

    assert namespace_only.namespace == "sys"
    assert namespace_only.key is None
# ==================== VariablePool 基础测试 ====================
@pytest.mark.asyncio
async def test_variable_pool_new_variable():
    """A newly created variable is visible through has() and get_value()."""
    variables = VariablePool()
    await variables.new("conv", "username", "Alice", VariableType.STRING, mut=True)

    assert variables.has("conv.username")
    assert variables.get_value("conv.username") == "Alice"
@pytest.mark.asyncio
async def test_variable_pool_new_multiple_variables():
    """Several variables of different types can be created in one namespace."""
    variables = VariablePool()
    specs = (
        ("name", "Bob", VariableType.STRING),
        ("age", 25, VariableType.NUMBER),
        ("active", True, VariableType.BOOLEAN),
    )
    for key, value, var_type in specs:
        await variables.new("conv", key, value, var_type, mut=True)

    assert variables.get_value("conv.name") == "Bob"
    assert variables.get_value("conv.age") == 25
    assert variables.get_value("conv.active") is True
@pytest.mark.asyncio
async def test_variable_pool_different_namespaces():
    """Variables in different namespaces are stored independently."""
    variables = VariablePool()
    specs = (
        ("sys", "message", "Hello", False),
        ("conv", "message", "World", True),
        ("node1", "output", "Result", False),
    )
    for namespace, key, value, mutable in specs:
        await variables.new(namespace, key, value, VariableType.STRING, mut=mutable)

    assert variables.get_value("sys.message") == "Hello"
    assert variables.get_value("conv.message") == "World"
    assert variables.get_value("node1.output") == "Result"
# ==================== get_value 测试 ====================
@pytest.mark.asyncio
async def test_get_value_with_template():
    """get_value() accepts template syntax with varying inner whitespace."""
    variables = VariablePool()
    await variables.new("conv", "test", "value", VariableType.STRING, mut=True)

    # Each spelling of the template must resolve to the same variable.
    for template in ("{{ conv.test }}", "{{conv.test}}", "{{ conv.test}}"):
        assert variables.get_value(template) == "value"
@pytest.mark.asyncio
async def test_get_value_not_exist_strict():
    """Reading a missing variable in the default (strict) mode raises KeyError."""
    variables = VariablePool()

    with pytest.raises(KeyError) as raised:
        variables.get_value("conv.nonexistent")

    error_text = str(raised.value)
    assert "not exist" in error_text
@pytest.mark.asyncio
async def test_get_value_not_exist_with_default():
    """With strict=False a missing variable yields the supplied default."""
    variables = VariablePool()

    fallback = variables.get_value("conv.nonexistent", default="default_value", strict=False)

    assert fallback == "default_value"
@pytest.mark.asyncio
async def test_get_value_different_types():
    """Values of every supported type are returned as they were stored."""
    variables = VariablePool()
    specs = (
        ("str", "text", VariableType.STRING),
        ("num", 42, VariableType.NUMBER),
        ("bool", False, VariableType.BOOLEAN),
        ("arr", [1, 2, 3], VariableType.ARRAY_NUMBER),
        ("obj", {"key": "value"}, VariableType.OBJECT),
    )
    for key, value, var_type in specs:
        await variables.new("conv", key, value, var_type, mut=True)

    assert variables.get_value("conv.str") == "text"
    assert variables.get_value("conv.num") == 42
    assert variables.get_value("conv.bool") is False
    assert variables.get_value("conv.arr") == [1, 2, 3]
    assert variables.get_value("conv.obj") == {"key": "value"}
# ==================== set 测试 ====================
@pytest.mark.asyncio
async def test_set_mutable_variable():
    """set() overwrites the value of a mutable variable."""
    variables = VariablePool()
    await variables.new("conv", "counter", 0, VariableType.NUMBER, mut=True)

    await variables.set("conv.counter", 10)

    assert variables.get_value("conv.counter") == 10
@pytest.mark.asyncio
async def test_set_immutable_variable():
    """set() on an immutable variable is rejected with KeyError."""
    variables = VariablePool()
    await variables.new("sys", "message", "original", VariableType.STRING, mut=False)

    with pytest.raises(KeyError) as raised:
        await variables.set("sys.message", "modified")

    error_text = str(raised.value)
    assert "cannot be modified" in error_text
@pytest.mark.asyncio
async def test_set_nonexistent_variable():
    """set() on an undefined variable raises KeyError."""
    variables = VariablePool()

    with pytest.raises(KeyError) as raised:
        await variables.set("conv.nonexistent", "value")

    error_text = str(raised.value)
    assert "is not defined" in error_text
@pytest.mark.asyncio
async def test_set_multiple_times():
    """After repeated set() calls the last written value wins."""
    variables = VariablePool()
    await variables.new("conv", "value", "first", VariableType.STRING, mut=True)

    for update in ("second", "third"):
        await variables.set("conv.value", update)

    assert variables.get_value("conv.value") == "third"
# ==================== has 测试 ====================
@pytest.mark.asyncio
async def test_has_existing_variable():
    """has() returns True for a variable that was created."""
    variables = VariablePool()
    await variables.new("conv", "test", "value", VariableType.STRING, mut=True)

    found = variables.has("conv.test")

    assert found is True
@pytest.mark.asyncio
async def test_has_nonexistent_variable():
    """has() returns False for a variable that was never created."""
    variables = VariablePool()

    found = variables.has("conv.nonexistent")

    assert found is False
# ==================== get_literal 测试 ====================
@pytest.mark.asyncio
async def test_get_literal():
    """get_literal() returns a string representation of the variable."""
    variables = VariablePool()
    await variables.new("conv", "num", 42, VariableType.NUMBER, mut=True)

    rendered = variables.get_literal("conv.num")

    assert isinstance(rendered, str)
# ==================== 命名空间操作测试 ====================
@pytest.mark.asyncio
async def test_get_all_system_vars():
    """get_all_system_vars() returns only the variables of the sys namespace."""
    variables = VariablePool()
    await variables.new("sys", "message", "Hello", VariableType.STRING, mut=False)
    await variables.new("sys", "user_id", "user123", VariableType.STRING, mut=False)
    await variables.new("conv", "other", "value", VariableType.STRING, mut=True)

    system_vars = variables.get_all_system_vars()

    assert "message" in system_vars
    assert "user_id" in system_vars
    assert "other" not in system_vars
    assert system_vars["message"] == "Hello"
    assert system_vars["user_id"] == "user123"
@pytest.mark.asyncio
async def test_get_all_conversation_vars():
    """get_all_conversation_vars() returns only the variables of the conv namespace."""
    variables = VariablePool()
    await variables.new("conv", "username", "Alice", VariableType.STRING, mut=True)
    await variables.new("conv", "score", 100, VariableType.NUMBER, mut=True)
    await variables.new("sys", "message", "Hello", VariableType.STRING, mut=False)

    conversation_vars = variables.get_all_conversation_vars()

    assert "username" in conversation_vars
    assert "score" in conversation_vars
    assert "message" not in conversation_vars
    assert conversation_vars["username"] == "Alice"
    assert conversation_vars["score"] == 100
@pytest.mark.asyncio
async def test_get_all_node_outputs():
    """get_all_node_outputs() groups node variables and leaves out sys and conv."""
    variables = VariablePool()
    specs = (
        ("node1", "output", "result1", False),
        ("node2", "output", "result2", False),
        ("sys", "message", "Hello", False),
        ("conv", "var", "value", True),
    )
    for namespace, key, value, mutable in specs:
        await variables.new(namespace, key, value, VariableType.STRING, mut=mutable)

    outputs_by_node = variables.get_all_node_outputs()

    assert "node1" in outputs_by_node
    assert "node2" in outputs_by_node
    assert "sys" not in outputs_by_node
    assert "conv" not in outputs_by_node
    assert outputs_by_node["node1"]["output"] == "result1"
    assert outputs_by_node["node2"]["output"] == "result2"
@pytest.mark.asyncio
async def test_get_node_output():
    """Check that get_node_output() returns every variable stored for one node."""
    vp = VariablePool()
    for key, value in (("output", "result"), ("status", "success")):
        await vp.new("node1", key, value, VariableType.STRING, mut=False)

    single_node = vp.get_node_output("node1")

    assert single_node["output"] == "result"
    assert single_node["status"] == "success"
@pytest.mark.asyncio
async def test_get_node_output_not_exist_strict():
    """Check that a missing node raises KeyError when no default is supplied."""
    vp = VariablePool()

    with pytest.raises(KeyError) as raised:
        vp.get_node_output("nonexistent_node")

    # The message check has to run after the context manager has captured the error.
    assert "output not exist" in str(raised.value)
@pytest.mark.asyncio
async def test_get_node_output_not_exist_with_default():
    """Check that a missing node yields the supplied default in non-strict mode."""
    vp = VariablePool()

    # NOTE(review): the keyword is spelled `defalut`; presumably this mirrors the
    # parameter name declared on VariablePool.get_node_output - confirm before renaming.
    fallback = vp.get_node_output("nonexistent_node", defalut=None, strict=False)

    assert fallback is None
# ==================== 复杂场景测试 ====================
@pytest.mark.asyncio
async def test_variable_pool_new_existing_mutable():
    """Check that re-creating an existing mutable variable replaces its value."""
    vp = VariablePool()
    for counter_value in (0, 10):
        await vp.new("conv", "counter", counter_value, VariableType.NUMBER, mut=True)

    assert vp.get_value("conv.counter") == 10
@pytest.mark.asyncio
async def test_variable_pool_new_existing_immutable():
    """Check that re-creating an existing immutable variable stores the new value."""
    vp = VariablePool()
    for text in ("original", "modified"):
        await vp.new("sys", "message", text, VariableType.STRING, mut=False)

    # new() on an immutable variable is expected to overwrite it.
    assert vp.get_value("sys.message") == "modified"
@pytest.mark.asyncio
async def test_variable_pool_zero_and_false_values():
    """Check that falsy values (0, False, "", [], {}) are stored and read back as-is."""
    vp = VariablePool()
    falsy_cases = (
        ("zero", 0, VariableType.NUMBER),
        ("false", False, VariableType.BOOLEAN),
        ("empty_str", "", VariableType.STRING),
        ("empty_arr", [], VariableType.ARRAY_NUMBER),
        ("empty_obj", {}, VariableType.OBJECT),
    )
    for key, value, var_type in falsy_cases:
        await vp.new("conv", key, value, var_type, mut=True)

    assert vp.get_value("conv.zero") == 0
    # Identity check: the boolean must come back as the False singleton.
    assert vp.get_value("conv.false") is False
    assert vp.get_value("conv.empty_str") == ""
    assert vp.get_value("conv.empty_arr") == []
    assert vp.get_value("conv.empty_obj") == {}
@pytest.mark.asyncio
async def test_variable_pool_nested_objects():
    """Check that a nested object variable keeps its inner structure."""
    address = {"city": "Beijing"}
    user = {"name": "Alice", "age": 25, "address": address}
    payload = {"user": user, "items": [1, 2, 3]}

    vp = VariablePool()
    await vp.new("conv", "data", payload, VariableType.OBJECT, mut=True)

    stored = vp.get_value("conv.data")
    assert stored["user"]["name"] == "Alice"
    assert stored["user"]["address"]["city"] == "Beijing"
    assert stored["items"] == [1, 2, 3]
@pytest.mark.asyncio
async def test_variable_pool_array_of_objects():
    """Check that an array-of-objects variable is read back element by element."""
    people = [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}]

    vp = VariablePool()
    await vp.new("conv", "users", people, VariableType.ARRAY_OBJECT, mut=True)

    stored = vp.get_value("conv.users")
    assert len(stored) == 2
    assert stored[0]["name"] == "Alice"
    assert stored[1]["age"] == 30
@pytest.mark.asyncio
async def test_variable_pool_to_dict():
    """Check that to_dict() exposes `system`, `conversation` and `nodes` sections."""
    vp = VariablePool()
    await vp.new("sys", "message", "Hello", VariableType.STRING, mut=False)
    await vp.new("conv", "username", "Alice", VariableType.STRING, mut=True)
    await vp.new("node1", "output", "result", VariableType.STRING, mut=False)

    exported = vp.to_dict()

    for section in ("system", "conversation", "nodes"):
        assert section in exported
    assert exported["system"]["message"] == "Hello"
    assert exported["conversation"]["username"] == "Alice"
    assert exported["nodes"]["node1"]["output"] == "result"
@pytest.mark.asyncio
async def test_variable_pool_copy():
    """Check that copy() clones values and that the clone is independent of its source."""
    source_pool = VariablePool()
    await source_pool.new("conv", "test", "value", VariableType.STRING, mut=True)

    cloned_pool = VariablePool()
    cloned_pool.copy(source_pool)
    assert cloned_pool.get_value("conv.test") == "value"

    # Writing through the clone must leave the source untouched.
    await cloned_pool.set("conv.test", "modified")
    assert cloned_pool.get_value("conv.test") == "modified"
    assert source_pool.get_value("conv.test") == "value"
@pytest.mark.asyncio
async def test_variable_pool_repr():
    """Check that repr() names the class and reports one variable per category."""
    vp = VariablePool()
    await vp.new("sys", "message", "Hello", VariableType.STRING, mut=False)
    await vp.new("conv", "username", "Alice", VariableType.STRING, mut=True)
    await vp.new("node1", "output", "result", VariableType.STRING, mut=False)

    text = repr(vp)

    expected_fragments = (
        "VariablePool",
        "system_vars=1",
        "conversation_vars=1",
        "runtime_vars=1",
    )
    for fragment in expected_fragments:
        assert fragment in text
# ==================== 并发测试 ====================
@pytest.mark.asyncio
async def test_variable_pool_concurrent_set():
    """Run two incrementing coroutines against one counter and expect a total of 200."""
    import asyncio

    vp = VariablePool()
    await vp.new("conv", "counter", 0, VariableType.NUMBER, mut=True)

    async def bump_counter():
        # Each worker performs 100 read-then-write increments.
        for _ in range(100):
            seen = vp.get_value("conv.counter")
            await vp.set("conv.counter", seen + 1)

    # Schedule both workers together.
    await asyncio.gather(*(bump_counter() for _ in range(2)))

    # The original expectation is that the pool's locking yields exactly 200.
    # NOTE(review): the read and the write are separate calls, so this only holds
    # while set() never lets the other worker run between them - confirm.
    assert vp.get_value("conv.counter") == 200
# ==================== 边界情况测试 ====================
@pytest.mark.asyncio
async def test_variable_pool_empty():
    """Check that a freshly created pool reports empty mappings for every view."""
    vp = VariablePool()

    system_view = vp.get_all_system_vars()
    assert system_view == {}
    conversation_view = vp.get_all_conversation_vars()
    assert conversation_view == {}
    node_view = vp.get_all_node_outputs()
    assert node_view == {}
@pytest.mark.asyncio
async def test_variable_selector_invalid():
    """Check that a selector with an extra path segment raises ValueError."""
    vp = VariablePool()
    await vp.new("conv", "test", "value", VariableType.STRING, mut=True)

    malformed_selector = "conv.test.extra"
    with pytest.raises(ValueError):
        vp.get_value(malformed_selector)
@pytest.mark.asyncio
async def test_variable_pool_special_characters():
    """Check that names containing underscores and digits are accepted."""
    vp = VariablePool()
    # Both the key and the namespace use underscores and digits.
    await vp.new("conv", "user_name_123", "Alice", VariableType.STRING, mut=True)
    await vp.new("node_1", "output_data", "result", VariableType.STRING, mut=False)

    expectations = (
        ("conv.user_name_123", "Alice"),
        ("node_1.output_data", "result"),
    )
    for selector, expected in expectations:
        assert vp.get_value(selector) == expected
@pytest.mark.asyncio
async def test_variable_pool_large_data():
    """Check that 100 conversation variables can be created and read back."""
    vp = VariablePool()

    # Populate var_0 .. var_99, each holding its own index.
    for idx in range(100):
        await vp.new("conv", f"var_{idx}", idx, VariableType.NUMBER, mut=True)

    # Every variable must be individually retrievable.
    for idx in range(100):
        assert vp.get_value(f"conv.var_{idx}") == idx

    assert len(vp.get_all_conversation_vars()) == 100
@pytest.mark.asyncio
async def test_variable_pool_different_types_same_name():
    """Check that the same key can hold distinct values in different namespaces."""
    vp = VariablePool()
    per_namespace = (
        ("sys", "system", False),
        ("conv", "conversation", True),
        ("node1", "node", False),
    )
    for namespace, value, mutable in per_namespace:
        await vp.new(namespace, "value", value, VariableType.STRING, mut=mutable)

    assert vp.get_value("sys.value") == "system"
    assert vp.get_value("conv.value") == "conversation"
    assert vp.get_value("node1.value") == "node"
@pytest.mark.asyncio
async def test_variable_pool_update_type():
    """Check that re-creating a variable with another type raises TypeError."""
    vp = VariablePool()

    # Start with a string variable.
    await vp.new("conv", "data", "text", VariableType.STRING, mut=True)
    assert vp.get_value("conv.data") == "text"

    # Re-declaring it as a number is expected to be rejected.
    with pytest.raises(TypeError):
        await vp.new("conv", "data", 123, VariableType.NUMBER, mut=True)

    # The original value must survive the rejected update.
    assert vp.get_value("conv.data") == "text"
@pytest.mark.asyncio
async def test_variable_pool_array_types():
    """Check that string, number, boolean and object arrays are read back unchanged."""
    vp = VariablePool()
    array_cases = (
        ("arr_str", ["a", "b", "c"], VariableType.ARRAY_STRING),
        ("arr_num", [1, 2, 3], VariableType.ARRAY_NUMBER),
        ("arr_bool", [True, False], VariableType.ARRAY_BOOLEAN),
        ("arr_obj", [{"id": 1}, {"id": 2}], VariableType.ARRAY_OBJECT),
    )
    for key, value, var_type in array_cases:
        await vp.new("conv", key, value, var_type, mut=True)

    assert vp.get_value("conv.arr_str") == ["a", "b", "c"]
    assert vp.get_value("conv.arr_num") == [1, 2, 3]
    assert vp.get_value("conv.arr_bool") == [True, False]
    assert vp.get_value("conv.arr_obj") == [{"id": 1}, {"id": 2}]
@pytest.mark.asyncio
async def test_variable_pool_namespace_isolation():
    """Check that each namespace view contains only its own variables."""
    vp = VariablePool()

    # One variable per namespace.
    await vp.new("sys", "var1", "sys_value", VariableType.STRING, mut=False)
    await vp.new("conv", "var2", "conv_value", VariableType.STRING, mut=True)
    await vp.new("node1", "var3", "node_value", VariableType.STRING, mut=False)

    system_view = vp.get_all_system_vars()
    conversation_view = vp.get_all_conversation_vars()
    node_view = vp.get_all_node_outputs()

    # System view: only var1.
    assert "var1" in system_view
    assert "var2" not in system_view
    assert "var3" not in system_view
    # Conversation view: only var2.
    assert "var2" in conversation_view
    assert "var1" not in conversation_view
    assert "var3" not in conversation_view
    # Node view: var3 nested under its node id.
    assert "node1" in node_view
    assert "var3" in node_view["node1"]
@pytest.mark.asyncio
async def test_variable_pool_mutability_rules():
    """Check that set() succeeds on mutable variables and raises KeyError otherwise."""
    vp = VariablePool()

    # System variable created immutable: set() must be rejected.
    await vp.new("sys", "immutable", "value", VariableType.STRING, mut=False)
    with pytest.raises(KeyError):
        await vp.set("sys.immutable", "new_value")

    # Conversation variable created mutable: set() must go through.
    await vp.new("conv", "mutable", "value", VariableType.STRING, mut=True)
    await vp.set("conv.mutable", "new_value")
    assert vp.get_value("conv.mutable") == "new_value"

    # Node output created immutable: set() must be rejected.
    await vp.new("node1", "output", "value", VariableType.STRING, mut=False)
    with pytest.raises(KeyError):
        await vp.set("node1.output", "new_value")
@pytest.mark.asyncio
async def test_variable_pool_template_variations():
    """Check that get_value() accepts `{{ ... }}` selectors with varied inner spacing."""
    vp = VariablePool()
    await vp.new("conv", "test", "value", VariableType.STRING, mut=True)

    # Every spelling below should resolve to the same variable.
    template_spellings = (
        "{{conv.test}}",
        "{{ conv.test }}",
        "{{ conv.test }}",
        "{{ conv.test}}",
        "{{conv.test }}",
    )
    for template in template_spellings:
        assert vp.get_value(template) == "value"

View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6 14:43

View File

@@ -0,0 +1,77 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/5 18:19
import os
import pytest
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable_pool import VariablePool
# Fixed identifiers shared by the workflow node tests.
TEST_WORKSPACE_ID = "test_workspace_id"
TEST_USER_ID = "test_user_id"
TEST_EXECUTION_ID = "test_execution_id"
TEST_CONVERSATION_ID = "test_conversation_id"

# The empty literal is a slot for pinning a model id locally; while it stays
# empty the value comes from the TEST_MODEL_ID environment variable.
TEST_MODEL_ID = "" or os.getenv("TEST_MODEL_ID")

# Sample file entry attached to the simulated user input.
TEST_FILE = {
    "type": "image",
    "url": "https://inews.gtimg.com/om_bt/Ojy0PdDIWWXRTAMh2QjsiumDZh-D1x7qCkDSmoaaX6INAAA/641",
    "__file": True,
}

# Simulated request payload read by the helpers in this module.
INPUT_DATA = dict(
    message="",
    variables=[],
    conversation_id=TEST_CONVERSATION_ID,
    files=[TEST_FILE],
)
@pytest.fixture(scope="session", autouse=True)
def global_precheck():
    """Fail the session early when no test model id has been configured.

    Raises:
        AssertionError: if TEST_MODEL_ID is empty or unset.
    """
    # NOTE(review): autouse fixtures are only picked up from conftest.py files or
    # from modules that pytest collects/imports the fixture into - confirm this
    # module is wired up so the check actually runs.
    assert TEST_MODEL_ID, 'PLEASE SET TEST_MODEL_ID FIRST'
def simple_state():
    """Return a fresh, minimal workflow state dict for node-level tests."""
    return dict(
        messages=[{"role": "user", "content": "123456"}],
        node_outputs={},
        execution_id=TEST_EXECUTION_ID,
        workspace_id=TEST_WORKSPACE_ID,
        user_id=TEST_USER_ID,
        error=None,
        error_node=None,
        # Ids of loop / iteration nodes.
        cycle_nodes=[],
        # Loop running flag; only used by loop nodes, not by the main loop.
        looping=0,
        activate={},
    )
async def simple_vairable_pool(message):
    """Create a VariablePool whose `sys` namespace is filled from the test inputs.

    Args:
        message: text stored as the `sys.message` variable.

    Returns:
        The populated VariablePool; every system variable is created with mut=False.
    """
    attached_files = INPUT_DATA.get("files") or []
    # An empty/missing "variables" entry falls back to an empty object.
    raw_inputs = INPUT_DATA.get("variables") or {}

    # (key, value, declared type) for each system variable, in creation order.
    system_definitions = (
        ("message", message, VariableType.STRING),
        ("conversation_id", INPUT_DATA.get("conversation_id"), VariableType.STRING),
        ("execution_id", TEST_EXECUTION_ID, VariableType.STRING),
        ("workspace_id", TEST_WORKSPACE_ID, VariableType.STRING),
        ("user_id", TEST_USER_ID, VariableType.STRING),
        ("input_variables", raw_inputs, VariableType.OBJECT),
        ("files", attached_files, VariableType.ARRAY_FILE),
    )

    pool = VariablePool()
    for name, value, declared_type in system_definitions:
        await pool.new(
            namespace='sys',
            key=name,
            value=value,
            var_type=VariableType(declared_type),
            mut=False,
        )
    return pool

View File

@@ -0,0 +1,834 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/5 18:54
import pytest
from app.core.workflow.nodes import AssignerNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
@pytest.mark.asyncio
async def test_assigner_number_add():
    """Check that `add` with value 3 turns a number variable from 1 into 4."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "add",
        "value": 3,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 4
@pytest.mark.asyncio
async def test_assigner_number_subtract():
    """Check that `subtract` with value 3 turns a number variable from 1 into -2."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "subtract",
        "value": 3,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == -2
@pytest.mark.asyncio
async def test_assigner_number_multiply():
    """Check that `multiply` with value 3 turns a number variable from 2 into 6."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 2, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "multiply",
        "value": 3,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 6
@pytest.mark.asyncio
async def test_assigner_number_divide():
    """Check that `divide` with value 2 turns a number variable from 6 into 3."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 6, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "divide",
        "value": 2,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 3
@pytest.mark.asyncio
async def test_assigner_number_assign():
    """Check that `assign` copies the value referenced by a template selector."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "test1", 4, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "assign",
        # The value is itself a reference to another conversation variable.
        "value": "{{conv.test1}}",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 4
@pytest.mark.asyncio
async def test_assigner_number_cover():
    """Check that `cover` replaces a number variable's value (1 -> 4)."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "cover",
        "value": 4,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 4
@pytest.mark.asyncio
async def test_assigner_number_clear():
    """Check that `clear` resets a number variable to 0."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "clear",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 0
@pytest.mark.asyncio
async def test_assigner_number_append():
    """Check that `append` on a number variable raises AttributeError."""
    state = simple_state()
    variable_pool = await simple_vairable_pool("test")
    await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    # Build the config outside the raises-block so that only the node execution
    # is allowed to raise the expected error.
    config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {
            "assignments": [
                {
                    "variable_selector": "{{conv.test}}",
                    "operation": "append",
                    "value": 3,
                }
            ]
        },
    }

    with pytest.raises(AttributeError) as exc_info:
        await AssignerNode(config, {}).execute(state, variable_pool)

    # Checked after the block: statements following the raising call inside
    # the `with` body would never run.
    assert "'NumberOperator' object has no attribute 'append'" in str(exc_info.value)
@pytest.mark.asyncio
async def test_assigner_number_remove_last():
    """Check that `remove_last` on a number variable raises AttributeError."""
    state = simple_state()
    variable_pool = await simple_vairable_pool("test")
    await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    # Build the config outside the raises-block so that only the node execution
    # is allowed to raise the expected error.
    config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {
            "assignments": [
                {
                    "variable_selector": "{{conv.test}}",
                    "operation": "remove_last",
                }
            ]
        },
    }

    with pytest.raises(AttributeError) as exc_info:
        await AssignerNode(config, {}).execute(state, variable_pool)

    # Checked after the block: statements following the raising call inside
    # the `with` body would never run.
    assert "'NumberOperator' object has no attribute 'remove_last'" in str(exc_info.value)
@pytest.mark.asyncio
async def test_assigner_number_remove_first():
    """Check that `remove_first` on a number variable raises AttributeError."""
    state = simple_state()
    variable_pool = await simple_vairable_pool("test")
    await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
    # Build the config outside the raises-block so that only the node execution
    # is allowed to raise the expected error.
    config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {
            "assignments": [
                {
                    "variable_selector": "{{conv.test}}",
                    "operation": "remove_first",
                }
            ]
        },
    }

    with pytest.raises(AttributeError) as exc_info:
        await AssignerNode(config, {}).execute(state, variable_pool)

    # Checked after the block: statements following the raising call inside
    # the `with` body would never run.
    assert "'NumberOperator' object has no attribute 'remove_first'" in str(exc_info.value)
@pytest.mark.asyncio
async def test_assigner_array_append():
    """Check that `append` adds 3 to the end of the number array [1, 2]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "append",
        "value": 3,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [1, 2, 3]
@pytest.mark.asyncio
async def test_assigner_array_remove_last():
    """Check that `remove_last` drops the final element of [1, 2]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "remove_last",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [1]
@pytest.mark.asyncio
async def test_assigner_array_remove_first():
    """Check that `remove_first` drops the leading element of [1, 2]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "remove_first",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [2]
# String tests
@pytest.mark.asyncio
async def test_assigner_string_assign():
    """Check that `assign` replaces a string variable ("hello" -> "world")."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "assign",
        "value": "world",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == "world"
@pytest.mark.asyncio
async def test_assigner_string_cover():
    """Check that `cover` replaces a string variable ("hello" -> "world")."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "cover",
        "value": "world",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == "world"
@pytest.mark.asyncio
async def test_assigner_string_clear():
    """Check that `clear` resets a string variable to the empty string."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "clear",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == ""
@pytest.mark.asyncio
async def test_assigner_string_invalid_operation():
    """Check that `add` on a string variable raises AttributeError."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "add",
        "value": "world",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }
    run_state = simple_state()

    with pytest.raises(AttributeError) as raised:
        await AssignerNode(node_config, {}).execute(run_state, pool)

    assert "'StringOperator' object has no attribute 'add'" in str(raised.value)
# Boolean tests
@pytest.mark.asyncio
async def test_assigner_boolean_assign():
    """Check that `assign` sets a boolean variable from True to False."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "assign",
        "value": False,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") is False
@pytest.mark.asyncio
async def test_assigner_boolean_cover():
    """Check that `cover` sets a boolean variable from False to True."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", False, VariableType.BOOLEAN, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "cover",
        "value": True,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") is True
@pytest.mark.asyncio
async def test_assigner_boolean_clear():
    """Check that `clear` resets a boolean variable to False."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "clear",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") is False
# Object tests
@pytest.mark.asyncio
async def test_assigner_object_assign():
    """Check that `assign` replaces an object variable with the given dict."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "assign",
        "value": {"new_key": "new_value"},
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == {"new_key": "new_value"}
@pytest.mark.asyncio
async def test_assigner_object_cover():
    """Check that `cover` replaces an object variable with the given dict."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "cover",
        "value": {"new_key": "new_value"},
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == {"new_key": "new_value"}
@pytest.mark.asyncio
async def test_assigner_object_clear():
    """Check that `clear` resets an object variable to an empty dict."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "clear",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == {}
# Array string tests
@pytest.mark.asyncio
async def test_assigner_array_string_append():
    """Check that `append` adds "c" to the end of the string array ["a", "b"]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "append",
        "value": "c",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == ["a", "b", "c"]
@pytest.mark.asyncio
async def test_assigner_array_string_clear():
    """Check that `clear` resets a string array to an empty list."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "clear",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == []
@pytest.mark.asyncio
async def test_assigner_array_object_append():
    """Check that `append` adds a dict to the end of an object array."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [{"id": 1}], VariableType.ARRAY_OBJECT, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "append",
        "value": {"id": 2},
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [{"id": 1}, {"id": 2}]
@pytest.mark.asyncio
async def test_assigner_array_assign():
    """Check that `assign` replaces the number array [1, 2] with [3, 4, 5]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "assign",
        "value": [3, 4, 5],
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [3, 4, 5]
@pytest.mark.asyncio
async def test_assigner_array_cover():
    """Check that `cover` replaces the number array [1, 2] with [3, 4, 5]."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "cover",
        "value": [3, 4, 5],
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [3, 4, 5]
# Multiple assignments test
@pytest.mark.asyncio
async def test_assigner_multiple_assignments():
    """Check that one node applies several assignments to different variables."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True)
    await pool.new("conv", "test2", "hello", VariableType.STRING, mut=True)
    await pool.new("conv", "test3", [1, 2], VariableType.ARRAY_NUMBER, mut=True)

    add_to_number = {
        "variable_selector": "{{conv.test1}}",
        "operation": "add",
        "value": 5,
    }
    replace_string = {
        "variable_selector": "{{conv.test2}}",
        "operation": "assign",
        "value": "world",
    }
    extend_array = {
        "variable_selector": "{{conv.test3}}",
        "operation": "append",
        "value": 3,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [add_to_number, replace_string, extend_array]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test1") == 15
    assert pool.get_value("conv.test2") == "world"
    assert pool.get_value("conv.test3") == [1, 2, 3]
# Variable reference test
@pytest.mark.asyncio
async def test_assigner_variable_reference():
    """Check that `assign` resolves a template value to the referenced variable."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "source", 100, VariableType.NUMBER, mut=True)
    await pool.new("conv", "target", 0, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.target}}",
        "operation": "assign",
        # Reference to the source variable rather than a literal.
        "value": "{{conv.source}}",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.target") == 100
# Edge cases
@pytest.mark.asyncio
async def test_assigner_divide_by_zero():
    """Check that `divide` with value 0 raises ZeroDivisionError."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 10, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "divide",
        "value": 0,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }
    run_state = simple_state()

    with pytest.raises(ZeroDivisionError):
        await AssignerNode(node_config, {}).execute(run_state, pool)
@pytest.mark.asyncio
async def test_assigner_invalid_namespace():
    """Check that assigning to a `sys` variable raises ValueError."""
    pool = await simple_vairable_pool("test")
    await pool.new("sys", "test", 10, VariableType.NUMBER, mut=False)
    assignment = {
        "variable_selector": "{{sys.test}}",
        "operation": "add",
        "value": 5,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }
    run_state = simple_state()

    with pytest.raises(ValueError) as raised:
        await AssignerNode(node_config, {}).execute(run_state, pool)

    assert "Only conversation or cycle variables can be assigned" in str(raised.value)
@pytest.mark.asyncio
async def test_assigner_empty_array_operations():
    """Check that `append` works on an initially empty number array."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [], VariableType.ARRAY_NUMBER, mut=True)
    # Append onto the empty array.
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "append",
        "value": 1,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == [1]
@pytest.mark.asyncio
async def test_assigner_remove_from_single_element_array():
    """Check that `remove_last` on a one-element array leaves an empty list."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", [1], VariableType.ARRAY_NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "remove_last",
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == []
@pytest.mark.asyncio
async def test_assigner_float_operations():
    """Check that `multiply` handles floats (10.5 * 2.0 == 21.0)."""
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "test", 10.5, VariableType.NUMBER, mut=True)
    assignment = {
        "variable_selector": "{{conv.test}}",
        "operation": "multiply",
        "value": 2.0,
    }
    node_config = {
        "id": "assigner_test",
        "type": "assigner",
        "name": "赋值测试节点",
        "config": {"assignments": [assignment]},
    }

    await AssignerNode(node_config, {}).execute(simple_state(), pool)

    assert pool.get_value("conv.test") == 21.0

View File

@@ -0,0 +1,23 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/5 19:15
import pytest
from app.core.workflow.nodes.breaker import BreakNode
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
@pytest.mark.asyncio
async def test_loop_breaker():
    """Check that executing a BreakNode sets the state's `looping` flag to 2."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    breaker_config = {
        "id": "breaker_test",
        "type": "breaker",
        "name": "breaker",
        "config": {},
    }

    await BreakNode(breaker_config, {}).execute(run_state, pool)

    assert run_state["looping"] == 2

View File

@@ -0,0 +1,279 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6 09:59
import pytest
from app.core.workflow.nodes.code import CodeNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
@pytest.mark.asyncio
async def test_code_python_complex_output():
    """Run a python3 snippet that returns one value for every supported output type.

    The ``code`` payload is base64 text; it decodes to URL-quoted Python source
    defining ``main(x, y)``.
    """
    declared_outputs = (
        ("number", VariableType.NUMBER),
        ("string", VariableType.STRING),
        ("boolean", VariableType.BOOLEAN),
        ("dict", VariableType.OBJECT),
        ("array_string", VariableType.ARRAY_STRING),
        ("array_number", VariableType.ARRAY_NUMBER),
        ("array_object", VariableType.ARRAY_OBJECT),
        ("array_boolean", VariableType.ARRAY_BOOLEAN),
    )
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjByZXR1cm4lMjAlN0IlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJudW1iZXIlMjIlM0ElMjB4JTIwJTJCJTIweSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMnN0cmluZyUyMiUzQSUyMHN0cih4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJib29sZWFuJTIyJTNBJTIwYm9vbCh4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJkaWN0JTIyJTNBJTIwJTdCJTIyc3VtJTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9zdHJpbmclMjIlM0ElMjAlNUJzdHIoeCUyMCUyQiUyMHkpJTVEJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyYXJyYXlfbnVtYmVyJTIyJTNBJTIwJTVCeCUyMCUyQiUyMHklNUQlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9vYmplY3QlMjIlM0ElMjAlNUIlN0IlMjJzdW0lMjIlM0ElMjB4JTIwJTJCJTIweSU3RCU1RCUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMmFycmF5X2Jvb2xlYW4lMjIlM0ElMjAlNUJib29sKHglMjAlMkIlMjB5KSU1RCUwQSUyMCUyMCUyMCUyMCU3RA==",
            "language": "python3",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": out_name, "type": out_type}
                for out_name, out_type in declared_outputs
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    outcome = await node.execute(run_state, pool)

    expected = {
        'number': 3,
        'string': '3',
        'boolean': True,
        'dict': {'sum': 3},
        'array_string': ['3'],
        'array_number': [3],
        'array_object': [{'sum': 3}],
        'array_boolean': [True],
    }
    assert outcome == expected
@pytest.mark.asyncio
async def test_code_javascript_complex_output():
    """Run a JavaScript snippet that returns one value for every supported output type.

    The ``code`` payload is base64 text; it decodes to a ``main({x, y})`` function.
    """
    declared_outputs = (
        ("number", VariableType.NUMBER),
        ("string", VariableType.STRING),
        ("boolean", VariableType.BOOLEAN),
        ("dict", VariableType.OBJECT),
        ("array_string", VariableType.ARRAY_STRING),
        ("array_number", VariableType.ARRAY_NUMBER),
        ("array_object", VariableType.ARRAY_OBJECT),
        ("array_boolean", VariableType.ARRAY_BOOLEAN),
    )
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "ZnVuY3Rpb24gbWFpbih7eCwgeX0pIHsKICBjb25zdCBzdW0gPSB4ICsgeTsKCiAgcmV0dXJuIHsKICAgIG51bWJlcjogc3VtLAogICAgc3RyaW5nOiBTdHJpbmcoc3VtKSwKICAgIGJvb2xlYW46IEJvb2xlYW4oc3VtKSwKICAgIGRpY3Q6IHsgc3VtIH0sCiAgICBhcnJheV9zdHJpbmc6IFtTdHJpbmcoc3VtKV0sCiAgICBhcnJheV9udW1iZXI6IFtzdW1dLAogICAgYXJyYXlfb2JqZWN0OiBbeyBzdW0gfV0sCiAgICBhcnJheV9ib29sZWFuOiBbQm9vbGVhbihzdW0pXSwKICB9Owp9",
            "language": "javascript",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": out_name, "type": out_type}
                for out_name, out_type in declared_outputs
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    outcome = await node.execute(run_state, pool)

    expected = {
        'number': 3,
        'string': '3',
        'boolean': True,
        'dict': {'sum': 3},
        'array_string': ['3'],
        'array_number': [3],
        'array_object': [{'sum': 3}],
        'array_boolean': [True],
    }
    assert outcome == expected
@pytest.mark.asyncio
async def test_code_python_operation_permissions():
    """A python3 snippet that touches ``os`` must fail with "Operation not permitted".

    The base64 payload decodes to URL-quoted source whose ``main`` imports
    ``os`` and calls ``os.getcwd()`` before returning.
    """
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjBpbXBvcnQlMjBvcyUwQSUyMCUyMCUyMCUyMG9zLmdldGN3ZCgpJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwJTdCJTIycmVzdWx0JTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMEE=",
            "language": "python3",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": "result", "type": "number"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    with pytest.raises(RuntimeError, match="Operation not permitted"):
        await node.execute(run_state, pool)
@pytest.mark.asyncio
async def test_code_javascript_operation_permissions():
    """A JavaScript snippet that queries the process must fail with "Operation not permitted".

    The base64 payload decodes to ``console.log(process.geteuid());``.
    """
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "Y29uc29sZS5sb2cocHJvY2Vzcy5nZXRldWlkKCkpOw==",
            "language": "javascript",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": "result", "type": "number"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    with pytest.raises(RuntimeError, match="Operation not permitted"):
        await node.execute(run_state, pool)
@pytest.mark.asyncio
async def test_code_python_run_error():
    """Syntactically broken python3 source must surface the interpreter's error text.

    The base64 payload decodes to URL-quoted source with an unclosed ``(`` in
    the ``main`` signature.
    """
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "ZGVmJTIwbWFpbih4JTJDJTIweSUzQSUwQSUyMCUyMCUyMCUyMHJldHVybiUyMCU3QiUyMnJlc3VsdCUyMiUzQSUyMHglMjAlMkIlMjB5JTdEJTBB",
            "language": "python3",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": "result", "type": "number"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    with pytest.raises(Exception) as exc_info:
        await node.execute(run_state, pool)

    # Checked after the context manager exits, once the exception has been captured.
    error_text = str(exc_info.value)
    assert "'(' was never closed" in error_text
@pytest.mark.asyncio
async def test_code_javascript_run_error():
    """Syntactically broken JavaScript must surface a ``SyntaxError`` message.

    The base64 payload decodes to the truncated statement ``console.log(``.
    """
    node_config = {
        "id": "code_test",
        "type": "code",
        "name": "代码执行",
        "config": {
            "code": "Y29uc29sZS5sb2co",
            "language": "javascript",
            "input_variables": [
                {"name": "x", "variable": "{{conv.x}}"},
                {"name": "y", "variable": "{{conv.y}}"},
            ],
            "output_variables": [
                {"name": "result", "type": "number"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
    await pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)

    node = CodeNode(node_config, {})
    with pytest.raises(Exception) as exc_info:
        await node.execute(run_state, pool)

    # Checked after the context manager exits, once the exception has been captured.
    error_text = str(exc_info.value)
    assert "SyntaxError" in error_text

View File

@@ -0,0 +1,42 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6 12:22
import pytest
from app.core.workflow.nodes import EndNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
@pytest.mark.asyncio
async def test_end_output():
    """EndNode renders ``{{conv.x}}{{sys.message}}`` as "1test" when conv.x is 1."""
    end_config = {
        "id": "end_test",
        "type": "end",
        "name": "end",
        "config": {"output": "{{conv.x}}{{sys.message}}"},
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)

    node = EndNode(end_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "1test"
@pytest.mark.asyncio
async def test_end_output_miss():
    """With conv.x never created, the same template renders as just "test"."""
    end_config = {
        "id": "end_test",
        "type": "end",
        "name": "end",
        "config": {"output": "{{conv.x}}{{sys.message}}"},
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")

    node = EndNode(end_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "test"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,889 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.nodes import JinjaRenderNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
# Basic template rendering config: one placeholder bound to conv.username.
SIMPLE_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "Hello, {{ name }}!",
        "mapping": [
            {"name": "name", "value": "conv.username"},
        ],
    },
}
# Multi-variable template config: three placeholders, each bound to a conv variable.
MULTI_VARIABLE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{{ greeting }}, {{ name }}! You are {{ age }} years old.",
        "mapping": [
            {"name": "greeting", "value": "conv.greeting"},
            {"name": "name", "value": "conv.name"},
            {"name": "age", "value": "conv.age"},
        ],
    },
}
# Conditional rendering config: if/else on conv.is_admin.
CONDITIONAL_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{% if is_admin %}Admin{% else %}User{% endif %}",
        "mapping": [
            {"name": "is_admin", "value": "conv.is_admin"},
        ],
    },
}
# Loop rendering config: comma-joins every element of conv.items.
LOOP_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}",
        "mapping": [
            {"name": "items", "value": "conv.items"},
        ],
    },
}
# Filter config: applies the `upper` filter to conv.text.
FILTER_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{{ text | upper }}",
        "mapping": [
            {"name": "text", "value": "conv.text"},
        ],
    },
}
# Object attribute access config: reads user.name and user.age from conv.user.
OBJECT_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "Name: {{ user.name }}, Age: {{ user.age }}",
        "mapping": [
            {"name": "user", "value": "conv.user"},
        ],
    },
}
# Arithmetic config: renders "a + b = <sum>" from conv.a and conv.b.
MATH_TEMPLATE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{{ a }} + {{ b }} = {{ a + b }}",
        "mapping": [
            {"name": "a", "value": "conv.a"},
            {"name": "b", "value": "conv.b"},
        ],
    },
}
# Default-value config: falls back to 'Guest' via the `default` filter.
DEFAULT_VALUE_CONFIG = {
    "id": "jinja_test",
    "type": "jinja-render",
    "name": "Jinja渲染测试节点",
    "config": {
        "template": "{{ name | default('Guest') }}",
        "mapping": [
            {"name": "name", "value": "conv.name"},
        ],
    },
}
# ==================== Basic template rendering tests ====================
@pytest.mark.asyncio
async def test_jinja_simple_template():
    """Render a single-placeholder template from conv.username."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)

    node = JinjaRenderNode(SIMPLE_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hello, Alice!"
@pytest.mark.asyncio
async def test_jinja_multi_variable():
    """Render a template that combines two string variables and one number."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "greeting", "Hi", VariableType.STRING, mut=True)
    await pool.new("conv", "name", "Bob", VariableType.STRING, mut=True)
    await pool.new("conv", "age", 25, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(MULTI_VARIABLE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hi, Bob! You are 25 years old."
# ==================== Conditional rendering tests ====================
@pytest.mark.asyncio
async def test_jinja_conditional_true():
    """The `if` branch is rendered when conv.is_admin is True."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "is_admin", True, VariableType.BOOLEAN, mut=True)

    node = JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Admin"
@pytest.mark.asyncio
async def test_jinja_conditional_false():
    """The `else` branch is rendered when conv.is_admin is False."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "is_admin", False, VariableType.BOOLEAN, mut=True)

    node = JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "User"
# ==================== Loop rendering tests ====================
@pytest.mark.asyncio
async def test_jinja_loop_array():
    """A three-element string array is rendered as a comma-separated list."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    fruit = ["apple", "banana", "cherry"]
    await pool.new("conv", "items", fruit, VariableType.ARRAY_STRING, mut=True)

    node = JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "apple, banana, cherry"
@pytest.mark.asyncio
async def test_jinja_loop_empty_array():
    """Looping over an empty array renders an empty string."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "items", [], VariableType.ARRAY_STRING, mut=True)

    node = JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == ""
@pytest.mark.asyncio
async def test_jinja_loop_single_item():
    """A one-element array renders without any trailing separator."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "items", ["apple"], VariableType.ARRAY_STRING, mut=True)

    node = JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "apple"
# ==================== Filter tests ====================
@pytest.mark.asyncio
async def test_jinja_filter_upper():
    """The `upper` filter upper-cases the bound text."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "text", "hello world", VariableType.STRING, mut=True)

    node = JinjaRenderNode(FILTER_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "HELLO WORLD"
@pytest.mark.asyncio
async def test_jinja_filter_lower():
    """The `lower` filter lower-cases the bound text."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ text | lower }}",
            "mapping": [
                {"name": "text", "value": "conv.text"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "text", "HELLO WORLD", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "hello world"
@pytest.mark.asyncio
async def test_jinja_filter_title():
    """The `title` filter capitalises each word of the bound text."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ text | title }}",
            "mapping": [
                {"name": "text", "value": "conv.text"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "text", "hello world", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hello World"
@pytest.mark.asyncio
async def test_jinja_filter_length():
    """The `length` filter reports the element count of a number array."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Length: {{ items | length }}",
            "mapping": [
                {"name": "items", "value": "conv.items"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "items", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Length: 5"
# ==================== Object attribute access tests ====================
@pytest.mark.asyncio
async def test_jinja_object_access():
    """Dotted access reads fields of an object variable."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    profile = {"name": "Alice", "age": 30}
    await pool.new("conv", "user", profile, VariableType.OBJECT, mut=True)

    node = JinjaRenderNode(OBJECT_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Name: Alice, Age: 30"
@pytest.mark.asyncio
async def test_jinja_nested_object():
    """Dotted access walks several levels of a nested object variable."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ data.user.name }} lives in {{ data.user.address.city }}",
            "mapping": [
                {"name": "data", "value": "conv.data"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    nested = {
        "user": {
            "name": "Bob",
            "address": {"city": "Beijing"},
        },
    }
    await pool.new("conv", "data", nested, VariableType.OBJECT, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Bob lives in Beijing"
# ==================== Arithmetic tests ====================
@pytest.mark.asyncio
async def test_jinja_math_addition():
    """Addition inside a template expression: 10 + 20 renders as 30."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "a", 10, VariableType.NUMBER, mut=True)
    await pool.new("conv", "b", 20, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(MATH_TEMPLATE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "10 + 20 = 30"
@pytest.mark.asyncio
async def test_jinja_math_subtraction():
    """Subtraction inside a template expression: 30 - 10 renders as 20."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ a }} - {{ b }} = {{ a - b }}",
            "mapping": [
                {"name": "a", "value": "conv.a"},
                {"name": "b", "value": "conv.b"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "a", 30, VariableType.NUMBER, mut=True)
    await pool.new("conv", "b", 10, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "30 - 10 = 20"
@pytest.mark.asyncio
async def test_jinja_math_multiplication():
    """Multiplication inside a template expression: 5 * 6 renders as 30."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ a }} * {{ b }} = {{ a * b }}",
            "mapping": [
                {"name": "a", "value": "conv.a"},
                {"name": "b", "value": "conv.b"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "a", 5, VariableType.NUMBER, mut=True)
    await pool.new("conv", "b", 6, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "5 * 6 = 30"
@pytest.mark.asyncio
async def test_jinja_math_division():
    """Division inside a template expression yields a float: 20 / 4 renders as 5.0."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ a }} / {{ b }} = {{ a / b }}",
            "mapping": [
                {"name": "a", "value": "conv.a"},
                {"name": "b", "value": "conv.b"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "a", 20, VariableType.NUMBER, mut=True)
    await pool.new("conv", "b", 4, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "20 / 4 = 5.0"
# ==================== Default value tests ====================
@pytest.mark.asyncio
async def test_jinja_default_value_missing():
    """The `default` filter supplies 'Guest' when the mapped variable is absent."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    # conv.name is deliberately not created.

    node = JinjaRenderNode(DEFAULT_VALUE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Guest"
@pytest.mark.asyncio
async def test_jinja_default_value_present():
    """The `default` filter is ignored when the mapped variable exists."""
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "name", "Alice", VariableType.STRING, mut=True)

    node = JinjaRenderNode(DEFAULT_VALUE_CONFIG, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Alice"
# ==================== String concatenation tests ====================
@pytest.mark.asyncio
async def test_jinja_string_concatenation():
    """The `~` operator joins two string variables with a literal space."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ first ~ ' ' ~ second }}",
            "mapping": [
                {"name": "first", "value": "conv.first"},
                {"name": "second", "value": "conv.second"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "first", "Hello", VariableType.STRING, mut=True)
    await pool.new("conv", "second", "World", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hello World"
# ==================== Comparison tests ====================
@pytest.mark.asyncio
async def test_jinja_comparison():
    """An if/elif ladder on a numeric score: 85 falls in the "B" band."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{% if score >= 90 %}A{% elif score >= 80 %}B{% elif score >= 70 %}C{% else %}D{% endif %}",
            "mapping": [
                {"name": "score", "value": "conv.score"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "score", 85, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "B"
# ==================== Array operation tests ====================
@pytest.mark.asyncio
async def test_jinja_array_index():
    """Positive and negative indexing both work on an array variable."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "First: {{ items[0] }}, Last: {{ items[-1] }}",
            "mapping": [
                {"name": "items", "value": "conv.items"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    ordinals = ["first", "second", "third"]
    await pool.new("conv", "items", ordinals, VariableType.ARRAY_STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "First: first, Last: third"
@pytest.mark.asyncio
async def test_jinja_array_slice():
    """Slicing ``numbers[1:4]`` of [1..5] iterates over 2, 3 and 4."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{% for n in numbers[1:4] %}{{ n }}{% endfor %}",
            "mapping": [
                {"name": "numbers", "value": "conv.numbers"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "numbers", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "234"
# ==================== Complex template tests ====================
@pytest.mark.asyncio
async def test_jinja_complex_template():
    """Loop over an array of objects, reading two fields from each element."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{% for user in users %}{{ user.name }} ({{ user.age }}){% if not loop.last %}, {% endif %}{% endfor %}",
            "mapping": [
                {"name": "users", "value": "conv.users"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    people = [
        {"name": "Alice", "age": 25},
        {"name": "Bob", "age": 30},
        {"name": "Charlie", "age": 35},
    ]
    await pool.new("conv", "users", people, VariableType.ARRAY_OBJECT, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Alice (25), Bob (30), Charlie (35)"
# ==================== Empty value handling tests ====================
@pytest.mark.asyncio
async def test_jinja_empty_string():
    """An empty string is falsy in an `if`, so the else branch is rendered."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{% if text %}{{ text }}{% else %}Empty{% endif %}",
            "mapping": [
                {"name": "text", "value": "conv.text"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "text", "", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Empty"
@pytest.mark.asyncio
async def test_jinja_zero_value():
    """A numeric zero is rendered as "0" rather than being dropped."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Count: {{ count }}",
            "mapping": [
                {"name": "count", "value": "conv.count"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "count", 0, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Count: 0"
# ==================== Special character tests ====================
@pytest.mark.asyncio
async def test_jinja_special_characters():
    """Double quotes in a value are rendered unchanged."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ text }}",
            "mapping": [
                {"name": "text", "value": "conv.text"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "text", "Hello \"World\"", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hello \"World\""
@pytest.mark.asyncio
async def test_jinja_newline():
    """A newline between two placeholders is preserved in the output."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ line1 }}\n{{ line2 }}",
            "mapping": [
                {"name": "line1", "value": "conv.line1"},
                {"name": "line2", "value": "conv.line2"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "line1", "First line", VariableType.STRING, mut=True)
    await pool.new("conv", "line2", "Second line", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "First line\nSecond line"
# ==================== Error handling tests ====================
@pytest.mark.asyncio
async def test_jinja_invalid_template():
    """A template with invalid syntax raises RuntimeError mentioning "render failed"."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ name",  # missing closing braces
            "mapping": [
                {"name": "name", "value": "conv.name"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "name", "Alice", VariableType.STRING, mut=True)

    node = JinjaRenderNode(node_config, {})
    with pytest.raises(RuntimeError) as exc_info:
        await node.execute(run_state, pool)

    # Checked after the context manager exits, once the exception has been captured.
    error_text = str(exc_info.value)
    assert "render failed" in error_text
@pytest.mark.asyncio
async def test_jinja_undefined_variable_strict_false():
    """An undefined variable (non-strict mode) renders as an empty string."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Hello, {{ undefined_var }}!",
            "mapping": [
                {"name": "undefined_var", "value": "conv.undefined"},
            ],
        },
    }
    run_state = simple_state()
    # No variables are created in the pool.
    pool = await simple_vairable_pool("test")

    node = JinjaRenderNode(node_config, {})
    # In non-strict mode the undefined placeholder collapses to "".
    rendered = await node.execute(run_state, pool)

    assert rendered == "Hello, !"
# ==================== Boolean tests ====================
@pytest.mark.asyncio
async def test_jinja_boolean_true():
    """A boolean True is rendered as the text "True"."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Flag is {{ flag }}",
            "mapping": [
                {"name": "flag", "value": "conv.flag"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "flag", True, VariableType.BOOLEAN, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Flag is True"
@pytest.mark.asyncio
async def test_jinja_boolean_false():
    """A boolean False is rendered as the text "False"."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Flag is {{ flag }}",
            "mapping": [
                {"name": "flag", "value": "conv.flag"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Flag is False"
# ==================== Float tests ====================
@pytest.mark.asyncio
async def test_jinja_float_number():
    """A float is rendered with its natural representation after a literal `$`."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "Price: ${{ price }}",
            "mapping": [
                {"name": "price", "value": "conv.price"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "price", 19.99, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "Price: $19.99"
@pytest.mark.asyncio
async def test_jinja_float_formatting():
    """The `format` filter with '%.2f' rounds 3.14159 to "3.14"."""
    node_config = {
        "id": "jinja_test",
        "type": "jinja-render",
        "name": "Jinja渲染测试节点",
        "config": {
            "template": "{{ '%.2f' | format(value) }}",
            "mapping": [
                {"name": "value", "value": "conv.value"},
            ],
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("test")
    await pool.new("conv", "value", 3.14159, VariableType.NUMBER, mut=True)

    node = JinjaRenderNode(node_config, {})
    rendered = await node.execute(run_state, pool)

    assert rendered == "3.14"

View File

@@ -0,0 +1,145 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/5 15:39
import pytest
from app.core.workflow.nodes import LLMNode
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
@pytest.mark.asyncio
async def test_llm_memory_no_stream():
    """Non-streaming LLM call with windowed memory enabled.

    The user message asks the model to repeat the previous sentence; the reply
    is expected to contain '123456'.
    NOTE(review): presumably simple_state() seeds history containing '123456' - confirm.
    """
    # Adjacent literals are concatenated without separators; kept exactly as authored.
    system_prompt = (
        "你是一个专业、友好且乐于助人的 AI 助手。"
        "你的职责:- "
        "准确理解用户的问题并提供有价值的回答"
        "- 保持回答的专业性和准确性"
        "- 如果不确定答案,诚实地告知用户"
        "- 使用清晰、易懂的语言进行交流"
        "回答风格:"
        "- 简洁明了,直击要点"
        "- 必要时提供详细解释和示例"
        "- 使用友好、礼貌的语气"
        "- 适当使用格式化(如列表、段落)提高可读性"
    )
    node_config = {
        "id": "llm_test",
        "type": "llm",
        "name": "LLM 问答",
        "config": {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "{{ sys.message }}"},
            ],
            "model_id": TEST_MODEL_ID,
            "temperature": 0.7,
            "max_tokens": 1000,
            "memory": {"enable": True, "enable_window": True, "window_size": 5},
            "vision": False,
            "vision_input": "{{sys.files}}",
        },
    }
    run_state = simple_state()
    pool = await simple_vairable_pool("输出上一句话")

    node = LLMNode(node_config, {})
    reply = await node.execute(run_state, pool)

    assert '123456' in reply.content
@pytest.mark.asyncio
async def test_llm_memory_stream():
    """Streaming LLM call with windowed memory enabled.

    The user message asks the model to repeat the previous sentence; the final
    streamed result is expected to contain '123456'.
    NOTE(review): presumably simple_state() seeds history containing '123456' - confirm.

    The test fails if the stream ends without any ``__final__`` event, so a
    missing final result is reported instead of silently passing.
    """
    # Adjacent literals are concatenated without separators; kept exactly as authored.
    system_prompt = (
        "你是一个专业、友好且乐于助人的 AI 助手。"
        "你的职责:- "
        "准确理解用户的问题并提供有价值的回答"
        "- 保持回答的专业性和准确性"
        "- 如果不确定答案,诚实地告知用户"
        "- 使用清晰、易懂的语言进行交流"
        "回答风格:"
        "- 简洁明了,直击要点"
        "- 必要时提供详细解释和示例"
        "- 使用友好、礼貌的语气"
        "- 适当使用格式化(如列表、段落)提高可读性"
    )
    node_config = {
        "id": "llm_test",
        "type": "llm",
        "name": "LLM 问答",
        "config": {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "{{ sys.message }}"},
            ],
            "model_id": TEST_MODEL_ID,
            "temperature": 0.7,
            "max_tokens": 1000,
            "memory": {"enable": True, "enable_window": True, "window_size": 5},
            "vision": False,
            "vision_input": "{{sys.files}}",
        },
    }
    state = simple_state()
    variable_pool = await simple_vairable_pool("输出上一句话")

    final_events = [
        event
        async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool)
        if event.get("__final__")
    ]

    # Guard against a vacuous pass: without this, a stream that never emits a
    # final event would skip every assertion below.
    assert final_events, "execute_stream finished without a __final__ event"
    for event in final_events:
        assert '123456' in event.get("result").content
@pytest.mark.asyncio
async def test_llm_vision():
    """Streaming LLM call with vision enabled, asking what the image contains.

    The original check ``'' in content`` is true for every string, so it could
    never fail on the reply text. This version requires at least one
    ``__final__`` event and a non-empty reply in each.
    NOTE(review): presumably simple_state() provides the image via sys.files - confirm.
    """
    # Adjacent literals are concatenated without separators; kept exactly as authored.
    system_prompt = (
        "你是一个专业、友好且乐于助人的 AI 助手。"
        "你的职责:- "
        "准确理解用户的问题并提供有价值的回答"
        "- 保持回答的专业性和准确性"
        "- 如果不确定答案,诚实地告知用户"
        "- 使用清晰、易懂的语言进行交流"
        "回答风格:"
        "- 简洁明了,直击要点"
        "- 必要时提供详细解释和示例"
        "- 使用友好、礼貌的语气"
        "- 适当使用格式化(如列表、段落)提高可读性"
    )
    node_config = {
        "id": "llm_test",
        "type": "llm",
        "name": "LLM 问答",
        "config": {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "{{ sys.message }}"},
            ],
            "model_id": TEST_MODEL_ID,
            "temperature": 0.7,
            "max_tokens": 1000,
            "memory": {"enable": True, "enable_window": True, "window_size": 5},
            "vision": True,
            "vision_input": "{{sys.files}}",
        },
    }
    state = simple_state()
    variable_pool = await simple_vairable_pool("图片里面有什么")

    final_events = [
        event
        async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool)
        if event.get("__final__")
    ]

    # Guard against a vacuous pass when the stream never emits a final event.
    assert final_events, "execute_stream finished without a __final__ event"
    for event in final_events:
        # The model's wording is not deterministic, so only require a non-empty reply.
        assert event.get("result").content, "vision reply content is empty"

View File

@@ -0,0 +1,504 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6 14:10
import pytest
from app.core.workflow.nodes import ParameterExtractorNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
def _param(name, kind, desc, required=True):
    """Build one parameter specification entry for the extractor config."""
    return {"name": name, "type": kind, "desc": desc, "required": required}


def _extractor_config(text, params, prompt=""):
    """Build a parameter-extractor node config around the given text and params."""
    return {
        "id": "param_extractor_test",
        "type": "parameter-extractor",
        "name": "参数提取测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "text": text,
            "params": params,
            "prompt": prompt,
        },
    }


# Basic extraction config: a single string parameter.
SINGLE_STRING_PARAM_CONFIG = _extractor_config(
    "我的名字是张三今年25岁",
    [_param("name", "string", "用户的姓名")],
)

# Multi-parameter config: two required parameters and one optional.
MULTI_PARAMS_CONFIG = _extractor_config(
    "我的名字是李四今年30岁住在北京",
    [
        _param("name", "string", "用户的姓名"),
        _param("age", "number", "用户的年龄"),
        _param("city", "string", "用户所在的城市", required=False),
    ],
)

# Number parameter config.
NUMBER_PARAM_CONFIG = _extractor_config(
    "这个产品的价格是99.99元库存有100件",
    [
        _param("price", "number", "产品价格"),
        _param("stock", "number", "库存数量"),
    ],
)

# Boolean parameter config.
BOOLEAN_PARAM_CONFIG = _extractor_config(
    "这个用户已经完成了实名认证,但还没有绑定手机号",
    [
        _param("verified", "boolean", "是否完成实名认证"),
        _param("phone_bound", "boolean", "是否绑定手机号"),
    ],
)

# String-array parameter config.
ARRAY_STRING_PARAM_CONFIG = _extractor_config(
    "我喜欢的水果有苹果、香蕉、橙子",
    [_param("fruits", "array[string]", "喜欢的水果列表")],
)

# Number-array parameter config.
ARRAY_NUMBER_PARAM_CONFIG = _extractor_config(
    "这个月的销售额分别是第一周10000第二周12000第三周15000第四周18000",
    [_param("weekly_sales", "array[number]", "每周的销售额")],
)

# Config with a custom extraction prompt.
CUSTOM_PROMPT_CONFIG = _extractor_config(
    "订单号ORD123456金额299元",
    [
        _param("order_id", "string", "订单编号"),
        _param("amount", "number", "订单金额"),
    ],
    prompt="请仔细提取订单信息,确保订单号和金额准确无误",
)

# Config whose input text is a conversation-variable template.
VARIABLE_INPUT_CONFIG = _extractor_config(
    "{{ conv.user_input }}",
    [
        _param("name", "string", "用户姓名"),
        _param("age", "number", "用户年龄"),
    ],
)
# ==================== 基础参数提取测试 ====================
@pytest.mark.asyncio
async def test_extract_single_string_param():
    """Extract a single string parameter and check it contains the expected name."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(SINGLE_STRING_PARAM_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    assert "name" in extracted
    name = extracted["name"]
    assert isinstance(name, str)
    assert "张三" in name
@pytest.mark.asyncio
async def test_extract_multi_params():
    """Extract name, age and city from one sentence and check each value."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(MULTI_PARAMS_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("name", "age", "city"):
        assert key in extracted
    assert isinstance(extracted["name"], str)
    assert isinstance(extracted["age"], (int, float))
    assert "李四" in extracted["name"]
    assert extracted["age"] == 30
    assert "北京" in extracted["city"]
# ==================== 数字参数提取测试 ====================
@pytest.mark.asyncio
async def test_extract_number_params():
    """Extract two numeric parameters (price and stock) and check their values."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(NUMBER_PARAM_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("price", "stock"):
        assert key in extracted
    price = extracted["price"]
    stock = extracted["stock"]
    assert isinstance(price, (int, float))
    assert isinstance(stock, (int, float))
    # Price is compared with a tolerance of 0.1 around 99.99.
    assert -0.1 < price - 99.99 < 0.1
    assert stock == 100
# ==================== 布尔参数提取测试 ====================
@pytest.mark.asyncio
async def test_extract_boolean_params():
    """Extract two boolean parameters: one expected True, the other False."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(BOOLEAN_PARAM_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("verified", "phone_bound"):
        assert key in extracted
    verified = extracted["verified"]
    phone_bound = extracted["phone_bound"]
    assert isinstance(verified, bool)
    assert isinstance(phone_bound, bool)
    assert verified is True
    assert phone_bound is False
# ==================== 数组参数提取测试 ====================
@pytest.mark.asyncio
async def test_extract_array_string_param():
    """Extract a list of strings and check the three expected fruits are present."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(ARRAY_STRING_PARAM_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    assert "fruits" in extracted
    fruits = extracted["fruits"]
    assert isinstance(fruits, list)
    assert len(fruits) >= 3
    for expected in ("苹果", "香蕉", "橙子"):
        assert expected in fruits
@pytest.mark.asyncio
async def test_extract_array_number_param():
    """Extract a list of numbers and check exactly the four weekly figures."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(ARRAY_NUMBER_PARAM_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    assert "weekly_sales" in extracted
    weekly_sales = extracted["weekly_sales"]
    assert isinstance(weekly_sales, list)
    assert len(weekly_sales) == 4
    for expected in (10000, 12000, 15000, 18000):
        assert expected in weekly_sales
# ==================== 自定义提示测试 ====================
@pytest.mark.asyncio
async def test_extract_with_custom_prompt():
    """Extract order id and amount using a config that carries a custom prompt."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = ParameterExtractorNode(CUSTOM_PROMPT_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("order_id", "amount"):
        assert key in extracted
    assert "ORD123456" in extracted["order_id"]
    amount = extracted["amount"]
    assert isinstance(amount, (int, float))
    assert amount == 299
# ==================== 变量输入测试 ====================
@pytest.mark.asyncio
async def test_extract_with_variable_input():
    """Extract parameters when the input text is the ``conv.user_input`` variable."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    # Register the conversation variable referenced by the config's text template.
    await pool.new("conv", "user_input", "我叫王五今年28岁", VariableType.STRING, mut=True)
    node = ParameterExtractorNode(VARIABLE_INPUT_CONFIG, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("name", "age"):
        assert key in extracted
    assert "王五" in extracted["name"]
    assert extracted["age"] == 28
# ==================== 复杂场景测试 ====================
@pytest.mark.asyncio
async def test_extract_from_complex_text():
    """Extract required and optional parameters from a multi-line customer record.

    ``name`` and ``age`` are required and always checked. ``occupation``,
    ``city`` and ``is_vip`` are optional and are checked only when the node
    returned an actual (non-None) value for them.
    """
    state = simple_state()
    variable_pool = await simple_vairable_pool("test")
    config = {
        "id": "param_extractor_test",
        "type": "parameter-extractor",
        "name": "参数提取测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "text": (
                "\n"
                "客户信息:\n"
                "姓名:赵六\n"
                "年龄35岁\n"
                "职业:软件工程师\n"
                "城市:上海\n"
                "邮箱zhaoliu@example.com\n"
                "是否VIP\n"
            ),
            "params": [
                {
                    "name": "name",
                    "type": "string",
                    "desc": "客户姓名",
                    "required": True
                },
                {
                    "name": "age",
                    "type": "number",
                    "desc": "客户年龄",
                    "required": True
                },
                {
                    "name": "occupation",
                    "type": "string",
                    "desc": "客户职业",
                    "required": False
                },
                {
                    "name": "city",
                    "type": "string",
                    "desc": "所在城市",
                    "required": False
                },
                {
                    "name": "is_vip",
                    "type": "boolean",
                    "desc": "是否为VIP客户",
                    "required": False
                }
            ],
            "prompt": ""
        }
    }

    result = await ParameterExtractorNode(config, {}).execute(state, variable_pool)

    assert isinstance(result, dict)
    assert "name" in result
    assert "age" in result
    assert "赵六" in result["name"]
    assert result["age"] == 35

    # An optional parameter may be missing or present with a None value
    # (TODO confirm against ParameterExtractorNode). A bare `in result` guard
    # would then crash with TypeError on the substring checks below.
    if result.get("occupation") is not None:
        assert "工程师" in result["occupation"]
    if result.get("city") is not None:
        assert "上海" in result["city"]
    # NOTE(review): the input line "是否VIP" carries no explicit value here;
    # confirm the text really implies True before relying on this check.
    if result.get("is_vip") is not None:
        assert result["is_vip"] is True
@pytest.mark.asyncio
async def test_extract_optional_params():
    """Extract from text that only provides the required parameter.

    Only ``name`` is asserted; the optional ``age`` and ``city`` are not checked.
    """
    state = simple_state()
    pool = await simple_vairable_pool("test")
    param_specs = [
        {"name": "name", "type": "string", "desc": "用户姓名", "required": True},
        {"name": "age", "type": "number", "desc": "用户年龄", "required": False},
        {"name": "city", "type": "string", "desc": "所在城市", "required": False},
    ]
    config = {
        "id": "param_extractor_test",
        "type": "parameter-extractor",
        "name": "参数提取测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "text": "我叫小明",
            "params": param_specs,
            "prompt": "",
        },
    }
    node = ParameterExtractorNode(config, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    assert "name" in extracted
    assert "小明" in extracted["name"]
    # Optional parameters may be absent or None, so they are not asserted.
@pytest.mark.asyncio
async def test_extract_with_sys_message():
    """Extract parameters when the input text is the ``sys.message`` variable."""
    state = simple_state()
    pool = await simple_vairable_pool("我叫小红今年22岁")
    param_specs = [
        {"name": "name", "type": "string", "desc": "用户姓名", "required": True},
        {"name": "age", "type": "number", "desc": "用户年龄", "required": True},
    ]
    config = {
        "id": "param_extractor_test",
        "type": "parameter-extractor",
        "name": "参数提取测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "text": "{{ sys.message }}",
            "params": param_specs,
            "prompt": "",
        },
    }
    node = ParameterExtractorNode(config, {})

    extracted = await node.execute(state, pool)

    assert isinstance(extracted, dict)
    for key in ("name", "age"):
        assert key in extracted
    assert "小红" in extracted["name"]
    assert extracted["age"] == 22

View File

@@ -0,0 +1,647 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.nodes import QuestionClassifierNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
# Prompts shared by every module-level classifier config below.
_SYSTEM_PROMPT = "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。"
_USER_PROMPT = "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。"


def _classifier_config(input_variable, class_names, supplement=None):
    """Build a question-classifier node config for the given input and categories."""
    return {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": input_variable,
            "categories": [{"class_name": label} for label in class_names],
            "system_prompt": _SYSTEM_PROMPT,
            "user_prompt": _USER_PROMPT,
            "user_supplement_prompt": supplement,
        },
    }


# Basic classification config: two categories.
BASIC_TWO_CATEGORIES_CONFIG = _classifier_config(
    "我想买一台笔记本电脑",
    ["产品咨询", "售后服务"],
)

# Config with four categories.
MULTI_CATEGORIES_CONFIG = _classifier_config(
    "我的订单什么时候能到?",
    ["产品咨询", "订单查询", "售后服务", "投诉建议"],
)

# Config that carries a supplementary prompt.
WITH_SUPPLEMENT_PROMPT_CONFIG = _classifier_config(
    "这个产品怎么样?",
    ["产品咨询", "用户评价"],
    supplement="如果用户在询问产品信息或特性,归类为产品咨询;如果是评价或反馈,归类为用户评价",
)

# Config whose input is a conversation-variable template.
VARIABLE_INPUT_CONFIG = _classifier_config(
    "{{ conv.user_question }}",
    ["技术支持", "账号问题"],
)

# Config whose input is the system message template.
SYS_MESSAGE_CONFIG = _classifier_config(
    "{{ sys.message }}",
    ["产品咨询", "售后服务"],
)

# Config with an empty question.
EMPTY_QUESTION_CONFIG = _classifier_config(
    "",
    ["产品咨询", "售后服务"],
)
# ==================== 基础分类测试 ====================
@pytest.mark.asyncio
async def test_classify_product_inquiry():
    """Classify a purchase question; expect the first category and CASE1."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = QuestionClassifierNode(BASIC_TWO_CATEGORIES_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    for key in ("class_name", "output"):
        assert key in outcome
    assert outcome["class_name"] == "产品咨询"
    assert outcome["output"] == "CASE1"
@pytest.mark.asyncio
async def test_classify_after_sales():
    """Classify a repair question; expect the second category and CASE2."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品咨询", "售后服务")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "我的产品坏了,怎么维修?",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "售后服务"
    assert outcome["output"] == "CASE2"
# ==================== 多类别分类测试 ====================
@pytest.mark.asyncio
async def test_classify_order_inquiry():
    """Classify a delivery question among four categories; expect CASE2."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = QuestionClassifierNode(MULTI_CATEGORIES_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "订单查询"
    assert outcome["output"] == "CASE2"
@pytest.mark.asyncio
async def test_classify_complaint():
    """Classify a complaint among four categories; expect the fourth and CASE4."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品咨询", "订单查询", "售后服务", "投诉建议")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "你们的服务态度太差了!",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "投诉建议"
    assert outcome["output"] == "CASE4"
# ==================== 补充提示测试 ====================
@pytest.mark.asyncio
async def test_classify_with_supplement_prompt():
    """Classify with a supplementary prompt; either configured category is accepted."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = QuestionClassifierNode(WITH_SUPPLEMENT_PROMPT_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    for key in ("class_name", "output"):
        assert key in outcome
    assert outcome["class_name"] in ["产品咨询", "用户评价"]
    assert outcome["output"] in ["CASE1", "CASE2"]
# ==================== 变量输入测试 ====================
@pytest.mark.asyncio
async def test_classify_with_conv_variable():
    """Classify a question supplied through the ``conv.user_question`` variable."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    # Register the conversation variable referenced by the config's input template.
    await pool.new("conv", "user_question", "我忘记密码了", VariableType.STRING, mut=True)
    node = QuestionClassifierNode(VARIABLE_INPUT_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "账号问题"
    assert outcome["output"] == "CASE2"
@pytest.mark.asyncio
async def test_classify_with_sys_message():
    """Classify a question supplied through the ``sys.message`` variable."""
    state = simple_state()
    pool = await simple_vairable_pool("我想了解一下你们的产品功能")
    node = QuestionClassifierNode(SYS_MESSAGE_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "产品咨询"
    assert outcome["output"] == "CASE1"
# ==================== 边界情况测试 ====================
@pytest.mark.asyncio
async def test_classify_empty_question():
    """Classify an empty question; the first category is expected as the result."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = QuestionClassifierNode(EMPTY_QUESTION_CONFIG, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    for key in ("class_name", "output"):
        assert key in outcome
    # An empty question should fall back to the default (first) category.
    assert outcome["class_name"] == "产品咨询"
    assert outcome["output"] == "CASE1"
@pytest.mark.asyncio
async def test_classify_single_category():
    """Classify when only one category is configured; expect it and CASE1."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "任何问题",
            "categories": [{"class_name": "通用咨询"}],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "通用咨询"
    assert outcome["output"] == "CASE1"
# ==================== 复杂场景测试 ====================
@pytest.mark.asyncio
async def test_classify_ambiguous_question():
    """Classify an ambiguous greeting; any of the three categories is accepted."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品咨询", "售后服务", "闲聊")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "你好",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] in ["产品咨询", "售后服务", "闲聊"]
    assert outcome["output"] in ["CASE1", "CASE2", "CASE3"]
@pytest.mark.asyncio
async def test_classify_long_question():
    """Classify a long after-sales question; expect the second category and CASE2."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品咨询", "售后服务")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "我在上个月购买了你们的产品,使用了一段时间后发现有一些问题,想咨询一下售后政策和维修流程,请问应该怎么办?",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "售后服务"
    assert outcome["output"] == "CASE2"
@pytest.mark.asyncio
async def test_classify_technical_support():
    """Classify an installation-error question; expect the first category and CASE1."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("技术支持", "账号问题")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "软件安装失败报错代码0x80070005",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "技术支持"
    assert outcome["output"] == "CASE1"
@pytest.mark.asyncio
async def test_classify_multiple_categories():
    """Classify a refund request among five categories; expect the third and CASE3."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品咨询", "订单查询", "退换货", "售后服务", "投诉建议")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "我想申请退款",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "退换货"
    assert outcome["output"] == "CASE3"
@pytest.mark.asyncio
async def test_classify_with_detailed_supplement():
    """Classify a how-to question with a detailed supplementary prompt; expect CASE1."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("产品使用", "产品介绍")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "这个功能怎么用?",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": "如果用户询问如何使用某个功能,归类为产品使用;如果询问功能是什么或有什么功能,归类为产品介绍",
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "产品使用"
    assert outcome["output"] == "CASE1"
@pytest.mark.asyncio
async def test_classify_chinese_categories():
    """Classify a complaint among three Chinese-named categories; expect CASE2."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("咨询类", "投诉类", "建议类")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "我要投诉",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] == "投诉类"
    assert outcome["output"] == "CASE2"
@pytest.mark.asyncio
async def test_classify_case_mapping():
    """Check that the returned class name maps to the matching ``CASE<n>`` output.

    Five categories are configured. Whichever one the node returns, its
    ``output`` must be ``CASE`` followed by the category's 1-based position.
    A class name outside the configured list now fails the test instead of
    skipping the mapping check.
    """
    state = simple_state()
    variable_pool = await simple_vairable_pool("test")
    category_names = ["类别A", "类别B", "类别C", "类别D", "类别E"]
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "测试问题",
            "categories": [{"class_name": label} for label in category_names],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None
        }
    }

    result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)

    assert isinstance(result, dict)
    assert "class_name" in result
    assert "output" in result

    # The mapping check used to be wrapped in `if class_name in category_names`,
    # so an unexpected class name passed without verifying anything.
    # NOTE(review): assumes the node always resolves to a configured category - confirm.
    assert result["class_name"] in category_names
    expected_case = f"CASE{category_names.index(result['class_name']) + 1}"
    assert result["output"] == expected_case
@pytest.mark.asyncio
async def test_classify_with_special_characters():
    """Classify a question with punctuation; either configured category is accepted."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    labels = ("价格咨询", "促销活动")
    config = {
        "id": "classifier_test",
        "type": "question-classifier",
        "name": "问题分类测试节点",
        "config": {
            "model_id": TEST_MODEL_ID,
            "input_variable": "产品价格是多少?有优惠吗?",
            "categories": [{"class_name": label} for label in labels],
            "system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
            "user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
            "user_supplement_prompt": None,
        },
    }
    node = QuestionClassifierNode(config, {})

    outcome = await node.execute(state, pool)

    assert isinstance(outcome, dict)
    assert outcome["class_name"] in ["价格咨询", "促销活动"]
    assert outcome["output"] in ["CASE1", "CASE2"]

View File

@@ -0,0 +1,735 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.nodes import StartNode
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from tests.workflow.nodes.base import (
simple_state,
simple_vairable_pool,
TEST_EXECUTION_ID,
TEST_WORKSPACE_ID,
TEST_USER_ID,
TEST_CONVERSATION_ID,
TEST_FILE
)
async def create_variable_pool_with_inputs(message: str, input_variables: dict = None):
    """Create a variable pool whose ``sys`` namespace carries custom input variables.

    Args:
        message: Value stored under ``sys.message``.
        input_variables: Mapping stored under ``sys.input_variables``; an
            empty dict is used when it is None or otherwise falsy.

    Returns:
        The populated ``VariablePool``.
    """
    pool = VariablePool()
    system_entries = {
        "message": (message, VariableType.STRING),
        "conversation_id": (TEST_CONVERSATION_ID, VariableType.STRING),
        "execution_id": (TEST_EXECUTION_ID, VariableType.STRING),
        "workspace_id": (TEST_WORKSPACE_ID, VariableType.STRING),
        "user_id": (TEST_USER_ID, VariableType.STRING),
        "input_variables": (input_variables or {}, VariableType.OBJECT),
        "files": ([TEST_FILE], VariableType.ARRAY_FILE),
    }
    for key, (value, var_type) in system_entries.items():
        await pool.new(
            namespace='sys',
            key=key,
            value=value,
            var_type=VariableType(var_type),
            mut=False,  # system variables are immutable
        )
    return pool
# Sentinel marking "no default key" so None, 0, False, [] and {} remain usable defaults.
_NO_DEFAULT = object()


def _variable(name, var_type, description, required=False, default=_NO_DEFAULT):
    """Build one start-node variable definition; ``default`` is omitted when unset."""
    spec = {"name": name, "type": var_type, "required": required}
    if default is not _NO_DEFAULT:
        spec["default"] = default
    spec["description"] = description
    return spec


def _start_config(*variables):
    """Build a start-node config declaring the given custom variables."""
    return {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {
            "variables": list(variables),
        },
    }


# Basic config: no custom variables.
BASIC_CONFIG = _start_config()

# Config with a single custom variable.
SINGLE_VARIABLE_CONFIG = _start_config(
    _variable("language", "string", "语言设置", default="zh-CN"),
)

# Config with several custom variables.
MULTI_VARIABLES_CONFIG = _start_config(
    _variable("language", "string", "语言设置", default="zh-CN"),
    _variable("max_length", "number", "最大长度", default=1000),
    _variable("enable_cache", "boolean", "是否启用缓存", default=True),
)

# Config with a required variable (no default).
REQUIRED_VARIABLE_CONFIG = _start_config(
    _variable("api_key", "string", "API密钥", required=True),
)

# Config mixing a required and an optional variable.
MIXED_VARIABLES_CONFIG = _start_config(
    _variable("user_id", "string", "用户ID", required=True),
    _variable("timeout", "number", "超时时间(秒)", default=30),
)

# Config covering variables of different types.
DIFFERENT_TYPES_CONFIG = _start_config(
    _variable("name", "string", "名称", default="default_name"),
    _variable("count", "number", "计数", default=0),
    _variable("enabled", "boolean", "是否启用", default=False),
    _variable("tags", "array[string]", "标签列表", default=[]),
    _variable("config", "object", "配置对象", default={}),
)
# ==================== 基础功能测试 ====================
@pytest.mark.asyncio
async def test_start_node_basic():
    """Run a start node without custom variables and check the system keys."""
    state = simple_state()
    pool = await simple_vairable_pool("test message")
    node = StartNode(BASIC_CONFIG, {})

    output = await node.execute(state, pool)

    assert isinstance(output, dict)
    for key in ("message", "execution_id", "conversation_id", "workspace_id", "user_id"):
        assert key in output
    assert output["message"] == "test message"
@pytest.mark.asyncio
async def test_start_node_system_variables():
    """Check that system values in the output match the message and the state."""
    state = simple_state()
    pool = await simple_vairable_pool("hello world")
    node = StartNode(BASIC_CONFIG, {})

    output = await node.execute(state, pool)

    assert output["message"] == "hello world"
    assert output["execution_id"] == state["execution_id"]
    assert output["workspace_id"] == state["workspace_id"]
    assert output["user_id"] == state["user_id"]
# ==================== 自定义变量测试 ====================
@pytest.mark.asyncio
async def test_start_node_single_variable_with_default():
    """A single custom variable without input takes its configured default."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})

    output = await node.execute(state, pool)

    assert "language" in output
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_single_variable_with_input():
    """A single custom variable takes the supplied input value over the default."""
    state = simple_state()
    # Build a pool that carries the input variables.
    supplied = {"language": "en-US"}
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})

    output = await node.execute(state, pool)

    assert "language" in output
    assert output["language"] == "en-US"
@pytest.mark.asyncio
async def test_start_node_multi_variables_with_defaults():
    """Several custom variables without input all take their configured defaults."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(MULTI_VARIABLES_CONFIG, {})

    output = await node.execute(state, pool)

    for key in ("language", "max_length", "enable_cache"):
        assert key in output
    assert output["language"] == "zh-CN"
    assert output["max_length"] == 1000
    assert output["enable_cache"] is True
@pytest.mark.asyncio
async def test_start_node_multi_variables_with_inputs():
    """Several custom variables all take their supplied input values."""
    state = simple_state()
    # Build a pool that carries the input variables.
    supplied = {"language": "ja-JP", "max_length": 2000, "enable_cache": False}
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(MULTI_VARIABLES_CONFIG, {})

    output = await node.execute(state, pool)

    assert output["language"] == "ja-JP"
    assert output["max_length"] == 2000
    assert output["enable_cache"] is False
@pytest.mark.asyncio
async def test_start_node_partial_inputs():
    """With partial input, supplied values win and the rest use their defaults."""
    state = simple_state()
    # Supply only one of the three declared variables.
    supplied = {"language": "fr-FR"}
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(MULTI_VARIABLES_CONFIG, {})

    output = await node.execute(state, pool)

    assert output["language"] == "fr-FR"  # input value
    assert output["max_length"] == 1000  # default value
    assert output["enable_cache"] is True  # default value
# ==================== 必需变量测试 ====================
@pytest.mark.asyncio
async def test_start_node_required_variable_provided():
    """A required variable that is supplied appears in the output unchanged."""
    state = simple_state()
    # Supply the required variable.
    supplied = {"api_key": "test_api_key_12345"}
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(REQUIRED_VARIABLE_CONFIG, {})

    output = await node.execute(state, pool)

    assert "api_key" in output
    assert output["api_key"] == "test_api_key_12345"
@pytest.mark.asyncio
async def test_start_node_required_variable_missing():
    """A missing required variable raises ValueError naming the variable."""
    state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(REQUIRED_VARIABLE_CONFIG, {})

    # The required variable is deliberately not supplied.
    with pytest.raises(ValueError) as exc_info:
        await node.execute(state, pool)

    message = str(exc_info.value)
    assert "缺少必需的输入变量" in message
    assert "api_key" in message
@pytest.mark.asyncio
async def test_start_node_mixed_variables():
    """With only the required variable supplied, the optional one uses its default."""
    state = simple_state()
    # Supply only the required variable.
    supplied = {"user_id": "user_123"}
    pool = await create_variable_pool_with_inputs("test", supplied)
    node = StartNode(MIXED_VARIABLES_CONFIG, {})

    output = await node.execute(state, pool)

    assert output["user_id"] == "user_123"  # required variable
    assert output["timeout"] == 30  # optional variable falls back to its default
@pytest.mark.asyncio
async def test_start_node_mixed_variables_all_provided():
    """When both ``user_id`` and ``timeout`` are supplied, both inputs win."""
    wf_state = simple_state()
    # Every variable checked below is given explicitly.
    inputs = dict(user_id="user_456", timeout=60)
    pool = await create_variable_pool_with_inputs("test", inputs)
    output = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(wf_state, pool)

    assert output["user_id"] == "user_456"
    assert output["timeout"] == 60
# ==================== 不同类型变量测试 ====================
@pytest.mark.asyncio
async def test_start_node_different_types_defaults():
    """Without inputs, each differently-typed variable yields its expected default."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node = StartNode(DIFFERENT_TYPES_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["name"] == "default_name"
    assert output["count"] == 0
    assert output["enabled"] is False
    assert output["tags"] == []
    assert output["config"] == {}
@pytest.mark.asyncio
async def test_start_node_different_types_inputs():
    """String, number, boolean, list and dict inputs must all pass through unchanged."""
    wf_state = simple_state()
    # One input per value type.
    inputs = dict(
        name="custom_name",
        count=100,
        enabled=True,
        tags=["tag1", "tag2", "tag3"],
        config={"key": "value", "nested": {"data": 123}},
    )
    pool = await create_variable_pool_with_inputs("test", inputs)
    output = await StartNode(DIFFERENT_TYPES_CONFIG, {}).execute(wf_state, pool)

    assert output["name"] == "custom_name"
    assert output["count"] == 100
    assert output["enabled"] is True
    assert output["tags"] == ["tag1", "tag2", "tag3"]
    assert output["config"] == {"key": "value", "nested": {"data": 123}}
# ==================== 边界情况测试 ====================
@pytest.mark.asyncio
async def test_start_node_empty_message():
    """An empty message must be returned as an empty string."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("")
    node = StartNode(BASIC_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["message"] == ""
@pytest.mark.asyncio
async def test_start_node_no_input_variables():
    """With no input variables set at all, ``language`` is expected to be "zh-CN"."""
    wf_state = simple_state()
    # The pool is built without any input variables.
    pool = await simple_vairable_pool("test")
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    # The default is expected to be used.
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_empty_input_variables():
    """An empty input dict must behave like no inputs: ``language`` is "zh-CN"."""
    wf_state = simple_state()
    # Explicitly pass an empty input-variable mapping.
    pool = await create_variable_pool_with_inputs("test", {})
    node = StartNode(SINGLE_VARIABLE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    # The default is expected to be used.
    assert output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_extra_input_variables():
    """An input the test treats as undeclared must not leak into the node output."""
    wf_state = simple_state()
    # "extra_var" is the additional input that should be dropped.
    inputs = dict(language="de-DE", extra_var="should_be_ignored")
    pool = await create_variable_pool_with_inputs("test", inputs)
    output = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(wf_state, pool)

    assert output["language"] == "de-DE"
    assert "extra_var" not in output  # the extra input must not be emitted
# ==================== 数组类型变量测试 ====================
@pytest.mark.asyncio
async def test_start_node_array_string_variable():
    """An ``array[string]`` variable must return the supplied list of strings."""
    wf_state = simple_state()
    categories_spec = {
        "name": "categories",
        "type": "array[string]",
        "required": False,
        "default": ["default1", "default2"],
        "description": "分类列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [categories_spec]},
    }
    pool = await create_variable_pool_with_inputs(
        "test",
        {"categories": ["cat1", "cat2", "cat3"]},
    )
    output = await StartNode(node_config, {}).execute(wf_state, pool)

    assert output["categories"] == ["cat1", "cat2", "cat3"]
@pytest.mark.asyncio
async def test_start_node_array_number_variable():
    """An ``array[number]`` variable must return the supplied list of numbers."""
    wf_state = simple_state()
    scores_spec = {
        "name": "scores",
        "type": "array[number]",
        "required": False,
        "default": [0, 0, 0],
        "description": "分数列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [scores_spec]},
    }
    pool = await create_variable_pool_with_inputs("test", {"scores": [85, 90, 95]})
    output = await StartNode(node_config, {}).execute(wf_state, pool)

    assert output["scores"] == [85, 90, 95]
@pytest.mark.asyncio
async def test_start_node_array_object_variable():
    """An ``array[object]`` variable must return the supplied list of dicts."""
    wf_state = simple_state()
    users_spec = {
        "name": "users",
        "type": "array[object]",
        "required": False,
        "default": [],
        "description": "用户列表",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [users_spec]},
    }
    inputs = {
        "users": [
            {"name": "Alice", "age": 25},
            {"name": "Bob", "age": 30},
        ],
    }
    pool = await create_variable_pool_with_inputs("test", inputs)
    output = await StartNode(node_config, {}).execute(wf_state, pool)

    assert len(output["users"]) == 2
    assert output["users"][0]["name"] == "Alice"
    assert output["users"][1]["age"] == 30
# ==================== 复杂场景测试 ====================
@pytest.mark.asyncio
async def test_start_node_complex_object():
    """A nested ``object`` input must be readable key by key in the node output."""
    wf_state = simple_state()
    settings_spec = {
        "name": "settings",
        "type": "object",
        "required": False,
        "default": {"theme": "light"},
        "description": "设置对象",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [settings_spec]},
    }
    # Input mixes scalars, a nested dict and a list.
    supplied_settings = {
        "theme": "dark",
        "language": "zh-CN",
        "notifications": {"email": True, "sms": False},
        "features": ["feature1", "feature2"],
    }
    pool = await create_variable_pool_with_inputs(
        "test",
        {"settings": supplied_settings},
    )
    output = await StartNode(node_config, {}).execute(wf_state, pool)

    assert output["settings"]["theme"] == "dark"
    assert output["settings"]["language"] == "zh-CN"
    assert output["settings"]["notifications"]["email"] is True
    assert output["settings"]["features"] == ["feature1", "feature2"]
@pytest.mark.asyncio
async def test_start_node_zero_and_false_values():
    """Falsy inputs (``0`` and ``False``) must not be replaced by the defaults."""
    wf_state = simple_state()
    count_spec = {
        "name": "count",
        "type": "number",
        "required": False,
        "default": 10,
        "description": "计数",
    }
    enabled_spec = {
        "name": "enabled",
        "type": "boolean",
        "required": False,
        "default": True,
        "description": "是否启用",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [count_spec, enabled_spec]},
    }
    pool = await create_variable_pool_with_inputs(
        "test",
        {"count": 0, "enabled": False},
    )
    output = await StartNode(node_config, {}).execute(wf_state, pool)

    # 0 and False are real values here; the defaults (10 / True) must not appear.
    assert output["count"] == 0
    assert output["enabled"] is False
@pytest.mark.asyncio
async def test_start_node_output_types():
    """After one execution, ``_output_types()`` must report the expected variable types."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    start_node = StartNode(MULTI_VARIABLES_CONFIG, {})
    await start_node.execute(wf_state, pool)
    declared = start_node._output_types()

    # System variables.
    assert declared["message"] == VariableType.STRING
    assert declared["execution_id"] == VariableType.STRING
    assert declared["conversation_id"] == VariableType.STRING
    assert declared["workspace_id"] == VariableType.STRING
    assert declared["user_id"] == VariableType.STRING

    # Custom variables.
    assert declared["language"] == VariableType.STRING
    assert declared["max_length"] == VariableType.NUMBER
    assert declared["enable_cache"] == VariableType.BOOLEAN
@pytest.mark.asyncio
async def test_start_node_multiple_executions():
    """One StartNode instance run twice must reflect each pool's own message."""
    wf_state = simple_state()
    start_node = StartNode(SINGLE_VARIABLE_CONFIG, {})

    # First run.
    first_pool = await create_variable_pool_with_inputs("first message", {})
    first_output = await start_node.execute(wf_state, first_pool)
    assert first_output["message"] == "first message"
    assert first_output["language"] == "zh-CN"

    # Second run against a fresh variable pool.
    second_pool = await create_variable_pool_with_inputs("second message", {})
    second_output = await start_node.execute(wf_state, second_pool)
    assert second_output["message"] == "second message"
    assert second_output["language"] == "zh-CN"
@pytest.mark.asyncio
async def test_start_node_with_description():
    """The missing-variable error must mention both the variable name and its description."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    endpoint_spec = {
        "name": "api_endpoint",
        "type": "string",
        "required": True,
        "description": "API 端点 URL用于连接外部服务",
    }
    node_config = {
        "id": "start_test",
        "type": "start",
        "name": "开始节点",
        "config": {"variables": [endpoint_spec]},
    }

    # The required variable is absent, so execution is expected to raise.
    with pytest.raises(ValueError) as raised:
        await StartNode(node_config, {}).execute(wf_state, pool)

    assert "api_endpoint" in str(raised.value)
    assert "API 端点 URL" in str(raised.value)

View File

@@ -0,0 +1,621 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/6
import pytest
from app.core.workflow.nodes import VariableAggregatorNode
from app.core.workflow.variable.base_variable import VariableType
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
# Non-group mode config: the aggregator returns the first non-empty variable.
NON_GROUP_CONFIG = {
    "id": "aggregator_test",
    "type": "var-aggregator",
    "name": "变量聚合测试节点",
    "config": {
        "group": False,
        "group_variables": ["{{conv.var1}}", "{{conv.var2}}", "{{conv.var3}}"],
    },
}
# Non-group mode config that also carries a type definition for the output.
NON_GROUP_WITH_TYPE_CONFIG = {
    "id": "aggregator_test",
    "type": "var-aggregator",
    "name": "变量聚合测试节点",
    "config": {
        "group": False,
        "group_variables": ["{{conv.var1}}", "{{conv.var2}}"],
        "group_type": {"output": "string"},
    },
}
# Group mode config: each named group lists its candidate variables in order.
GROUP_CONFIG = {
    "id": "aggregator_test",
    "type": "var-aggregator",
    "name": "变量聚合测试节点",
    "config": {
        "group": True,
        "group_variables": {
            "user_message": ["{{conv.msg1}}", "{{conv.msg2}}"],
            "user_name": ["{{conv.name1}}", "{{conv.name2}}"],
        },
    },
}
# Group mode config that also carries a type definition per group.
GROUP_WITH_TYPE_CONFIG = {
    "id": "aggregator_test",
    "type": "var-aggregator",
    "name": "变量聚合测试节点",
    "config": {
        "group": True,
        "group_variables": {
            "count": ["{{conv.count1}}", "{{conv.count2}}"],
            "enabled": ["{{conv.flag1}}", "{{conv.flag2}}"],
        },
        "group_type": {"count": "number", "enabled": "boolean"},
    },
}
# ==================== 非分组模式测试 ====================
@pytest.mark.asyncio
async def test_non_group_first_variable():
    """Non-group mode: with all three variables set, the first one is returned."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # Register var1..var3 in listing order.
    for var_name, text in (
        ("var1", "first_value"),
        ("var2", "second_value"),
        ("var3", "third_value"),
    ):
        await pool.new("conv", var_name, text, VariableType.STRING, mut=True)

    node = VariableAggregatorNode(NON_GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output == "first_value"
@pytest.mark.asyncio
async def test_non_group_skip_none():
    """Non-group mode: a missing first variable is skipped in favour of the second."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # var1 is deliberately not created; only var2 and var3 exist.
    for var_name, text in (("var2", "second_value"), ("var3", "third_value")):
        await pool.new("conv", var_name, text, VariableType.STRING, mut=True)

    node = VariableAggregatorNode(NON_GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output == "second_value"
@pytest.mark.asyncio
async def test_non_group_all_none():
    """Non-group mode: when none of the variables exist, an empty string is returned."""
    wf_state = simple_state()
    # No conversation variables are created for this case.
    pool = await simple_vairable_pool("test")
    node = VariableAggregatorNode(NON_GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output == ""
@pytest.mark.asyncio
async def test_non_group_with_type_all_none():
    """Non-group mode with a string output type: no variables yields an empty string."""
    wf_state = simple_state()
    # No conversation variables are created for this case.
    pool = await simple_vairable_pool("test")
    node = VariableAggregatorNode(NON_GROUP_WITH_TYPE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    # The default for the declared type is expected.
    assert output == ""
@pytest.mark.asyncio
async def test_non_group_different_types():
    """Non-group mode over mixed-type variables returns the first one (a number)."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.num}}", "{{conv.str}}", "{{conv.bool}}"],
        },
    }

    # One variable per type, registered in listing order.
    for var_name, value, var_type in (
        ("num", 123, VariableType.NUMBER),
        ("str", "text", VariableType.STRING),
        ("bool", True, VariableType.BOOLEAN),
    ):
        await pool.new("conv", var_name, value, var_type, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    assert output == 123
@pytest.mark.asyncio
async def test_non_group_zero_and_false():
    """Non-group mode: a leading ``0`` is a real value and must not fall through."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.zero}}", "{{conv.text}}"],
        },
    }

    # The first variable holds zero; the second is the would-be fallback.
    await pool.new("conv", "zero", 0, VariableType.NUMBER, mut=True)
    await pool.new("conv", "text", "fallback", VariableType.STRING, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    # 0 must not be treated as None, so 0 is returned rather than "fallback".
    assert output == 0
@pytest.mark.asyncio
async def test_non_group_false_value():
    """Non-group mode: a leading ``False`` is a real value and must not fall through."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.flag}}", "{{conv.text}}"],
        },
    }

    # The first variable holds False; the second is the would-be fallback.
    await pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True)
    await pool.new("conv", "text", "fallback", VariableType.STRING, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    # False must not be treated as None, so False is returned rather than "fallback".
    assert output is False
# ==================== 分组模式测试 ====================
@pytest.mark.asyncio
async def test_group_mode_all_groups():
"""测试分组模式所有分组都有值"""
state = simple_state()
variable_pool = await simple_vairable_pool("test")
# 设置变量
await variable_pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True)
await variable_pool.new("conv", "name1", "Alice", VariableType.STRING, mut=True)
result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool)
assert isinstance(result, dict)
assert result["user_message"] == "Hello"
assert result["user_name"] == "Alice"
@pytest.mark.asyncio
async def test_group_mode_fallback():
    """Group mode: when a group's first variable is missing, its second one is used."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # Only the second candidate of each group exists.
    for var_name, text in (("msg2", "Fallback message"), ("name2", "Bob")):
        await pool.new("conv", var_name, text, VariableType.STRING, mut=True)

    node = VariableAggregatorNode(GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["user_message"] == "Fallback message"
    assert output["user_name"] == "Bob"
@pytest.mark.asyncio
async def test_group_mode_partial_none():
    """Group mode: a group with no available variable comes back as an empty string."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # Only the user_message group gets a value.
    await pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True)

    node = VariableAggregatorNode(GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["user_message"] == "Hello"
    assert output["user_name"] == ""  # group without a value yields ""
@pytest.mark.asyncio
async def test_group_mode_all_none():
    """Group mode: with no variables at all, every group is an empty string."""
    wf_state = simple_state()
    # No conversation variables are created for this case.
    pool = await simple_vairable_pool("test")
    node = VariableAggregatorNode(GROUP_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert isinstance(output, dict)
    assert output["user_message"] == ""
    assert output["user_name"] == ""
@pytest.mark.asyncio
async def test_group_mode_with_type():
    """Group mode with type definitions: set values are returned per group."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # Populate the first candidate of each typed group.
    await pool.new("conv", "count1", 100, VariableType.NUMBER, mut=True)
    await pool.new("conv", "flag1", True, VariableType.BOOLEAN, mut=True)

    node = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["count"] == 100
    assert output["enabled"] is True
@pytest.mark.asyncio
async def test_group_mode_with_type_defaults():
    """Group mode with type definitions: empty groups fall back to per-type defaults."""
    wf_state = simple_state()
    # No conversation variables are created for this case.
    pool = await simple_vairable_pool("test")
    node = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    # Each group is expected to carry the default for its declared type.
    assert output["count"] == 0  # default for number
    assert output["enabled"] is False  # default for boolean
@pytest.mark.asyncio
async def test_group_mode_mixed_values():
    """Group mode: one group resolves via its second variable, the other uses a default."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")

    # Only count2 exists; neither flag variable is created.
    await pool.new("conv", "count2", 200, VariableType.NUMBER, mut=True)

    node = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {})
    output = await node.execute(wf_state, pool)

    assert output["count"] == 200  # resolved from the second variable
    assert output["enabled"] is False  # no value, so the default is expected
@pytest.mark.asyncio
async def test_group_mode_multiple_groups():
    """Group mode with three groups: each resolves independently of the others."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": True,
            "group_variables": {
                "group1": ["{{conv.g1_v1}}", "{{conv.g1_v2}}"],
                "group2": ["{{conv.g2_v1}}", "{{conv.g2_v2}}"],
                "group3": ["{{conv.g3_v1}}", "{{conv.g3_v2}}"],
            },
        },
    }

    # group1 and group3 get their first candidate, group2 only its second.
    for var_name, text in (
        ("g1_v1", "value1"),
        ("g2_v2", "value2"),
        ("g3_v1", "value3"),
    ):
        await pool.new("conv", var_name, text, VariableType.STRING, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    assert output["group1"] == "value1"
    assert output["group2"] == "value2"
    assert output["group3"] == "value3"
# ==================== 复杂场景测试 ====================
@pytest.mark.asyncio
async def test_aggregator_with_array():
    """Non-group mode over two array variables returns the first array."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.arr1}}", "{{conv.arr2}}"],
        },
    }

    # Both array variables exist.
    for var_name, numbers in (("arr1", [1, 2, 3]), ("arr2", [4, 5, 6])):
        await pool.new("conv", var_name, numbers, VariableType.ARRAY_NUMBER, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    assert output == [1, 2, 3]
@pytest.mark.asyncio
async def test_aggregator_with_object():
    """Non-group mode over two object variables returns the first object."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.obj1}}", "{{conv.obj2}}"],
        },
    }

    # Both object variables exist.
    for var_name, payload in (
        ("obj1", {"key": "value1"}),
        ("obj2", {"key": "value2"}),
    ):
        await pool.new("conv", var_name, payload, VariableType.OBJECT, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    assert output == {"key": "value1"}
@pytest.mark.asyncio
async def test_aggregator_empty_string():
    """A leading empty string is a real value and must not fall through to the next one."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.empty}}", "{{conv.text}}"],
        },
    }

    # The first variable is an empty string; the second is the would-be fallback.
    for var_name, text in (("empty", ""), ("text", "fallback")):
        await pool.new("conv", var_name, text, VariableType.STRING, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    # "" must not be treated as None, so "" is returned rather than "fallback".
    assert output == ""
@pytest.mark.asyncio
async def test_aggregator_empty_array():
    """A leading empty array is a real value and must not fall through to the next one."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.empty_arr}}", "{{conv.arr}}"],
        },
    }

    # The first variable is an empty array; the second is the would-be fallback.
    for var_name, numbers in (("empty_arr", []), ("arr", [1, 2])):
        await pool.new("conv", var_name, numbers, VariableType.ARRAY_NUMBER, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    # [] must not be treated as None, so [] is returned rather than [1, 2].
    assert output == []
@pytest.mark.asyncio
async def test_aggregator_empty_object():
    """A leading empty object is a real value and must not fall through to the next one."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.empty_obj}}", "{{conv.obj}}"],
        },
    }

    # The first variable is an empty object; the second is the would-be fallback.
    for var_name, payload in (("empty_obj", {}), ("obj", {"key": "value"})):
        await pool.new("conv", var_name, payload, VariableType.OBJECT, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    # {} must not be treated as None, so {} is returned rather than the second object.
    assert output == {}
@pytest.mark.asyncio
async def test_group_mode_with_different_types():
    """Group mode with one group per value type returns each group's first variable."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    grouped_variables = {
        "text": ["{{conv.str1}}", "{{conv.str2}}"],
        "number": ["{{conv.num1}}", "{{conv.num2}}"],
        "array": ["{{conv.arr1}}", "{{conv.arr2}}"],
        "object": ["{{conv.obj1}}", "{{conv.obj2}}"],
    }
    grouped_types = {
        "text": "string",
        "number": "number",
        "array": "array[number]",
        "object": "object",
    }
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": True,
            "group_variables": grouped_variables,
            "group_type": grouped_types,
        },
    }

    # Populate the first candidate of every group, one per type.
    for var_name, value, var_type in (
        ("str1", "hello", VariableType.STRING),
        ("num1", 42, VariableType.NUMBER),
        ("arr1", [1, 2, 3], VariableType.ARRAY_NUMBER),
        ("obj1", {"key": "value"}, VariableType.OBJECT),
    ):
        await pool.new("conv", var_name, value, var_type, mut=True)

    output = await VariableAggregatorNode(node_config, {}).execute(wf_state, pool)

    assert output["text"] == "hello"
    assert output["number"] == 42
    assert output["array"] == [1, 2, 3]
    assert output["object"] == {"key": "value"}
@pytest.mark.asyncio
async def test_aggregator_output_types():
    """``_output_types()`` must map each typed group to its VariableType."""
    # State and pool are created as in the sibling tests but are not passed to the node.
    _wf_state = simple_state()
    _pool = await simple_vairable_pool("test")

    aggregator = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {})
    declared = aggregator._output_types()

    assert declared["count"] == VariableType.NUMBER
    assert declared["enabled"] == VariableType.BOOLEAN
@pytest.mark.asyncio
async def test_non_group_single_variable():
    """Non-group mode with a single listed variable returns that variable's value."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": False,
            "group_variables": ["{{conv.only_var}}"],
        },
    }

    await pool.new("conv", "only_var", "single_value", VariableType.STRING, mut=True)

    node = VariableAggregatorNode(node_config, {})
    output = await node.execute(wf_state, pool)

    assert output == "single_value"
@pytest.mark.asyncio
async def test_group_mode_single_group():
    """Group mode with a single group still returns a dict keyed by the group name."""
    wf_state = simple_state()
    pool = await simple_vairable_pool("test")
    node_config = {
        "id": "aggregator_test",
        "type": "var-aggregator",
        "name": "变量聚合测试节点",
        "config": {
            "group": True,
            "group_variables": {
                "only_group": ["{{conv.var1}}", "{{conv.var2}}"],
            },
        },
    }

    await pool.new("conv", "var1", "value", VariableType.STRING, mut=True)

    node = VariableAggregatorNode(node_config, {})
    output = await node.execute(wf_state, pool)

    assert isinstance(output, dict)
    assert output["only_group"] == "value"

View File

@@ -36,7 +36,7 @@ async def run_code(request: RunCodeRequest):
elif request.language == "javascript":
return await run_nodejs_code(request.code, request.preload, request.options)
else:
return error_response(-400, "unsupported language")
return error_response(400, "unsupported language")
@router.get("/dependencies", response_model=ApiResponse)
@@ -45,7 +45,7 @@ async def get_dependencies(language: str):
if language == "python3":
return await list_python_dependencies()
else:
return error_response(-400, "unsupported language")
return error_response(400, "unsupported language")
@router.post("/dependencies/update", response_model=ApiResponse)
@@ -54,4 +54,4 @@ async def update_dependencies(request: UpdateDependencyRequest):
if request.language == "python3":
return await update_python_dependencies()
else:
return error_response(-400, "unsupported language")
return error_response(400, "unsupported language")

View File

@@ -75,6 +75,4 @@ def success_response(data: Any) -> ApiResponse:
def error_response(code: int, message: str) -> ApiResponse:
"""Create error response"""
if code >= 0:
code = -1
return ApiResponse(code=code, message=message, data=None)

View File

@@ -27,11 +27,11 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions):
try:
runner = NodejsRunner()
result = await runner.run(code, options, preload)
if result.exit_code == signal.SIGSYS + 0x80:
if result.exit_code in [signal.SIGSYS + 0x80, -signal.SIGSYS]:
return error_response(31, "sandbox security policy violation")
if result.exit_code != 0:
return error_response(500, result.stderr)
return error_response(result.exit_code, result.stderr)
return success_response(RunCodeResponse(
stdout=result.stdout,
@@ -39,5 +39,5 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions):
))
except Exception as e:
logger.error(f"Python execution failed: {e}", exc_info=True)
return error_response(-500, str(e))
logger.error(f"JavaScript execution failed: {e}", exc_info=True)
return error_response(500, str(e))

View File

@@ -47,7 +47,7 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions):
except Exception as e:
logger.error(f"Python execution failed: {e}", exc_info=True)
return error_response(-500, str(e))
return error_response(500, str(e))
async def list_python_dependencies():