Merge branch 'develop' into fix/memory-enduser-config
This commit is contained in:
28618
api/General_purpose_entity.ttl
Normal file
28618
api/General_purpose_entity.ttl
Normal file
File diff suppressed because it is too large
Load Diff
@@ -290,7 +290,8 @@ async def pilot_run(
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, language={language}"
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
|
||||
@@ -4,13 +4,14 @@
|
||||
|
||||
Endpoints:
|
||||
POST /api/memory/ontology/extract - 提取本体类
|
||||
POST /api/memory/ontology/export - 导出OWL文件
|
||||
POST /api/memory/ontology/export - 按场景导出OWL文件
|
||||
POST /api/memory/ontology/import - 导入OWL文件到指定场景
|
||||
POST /api/memory/ontology/scene - 创建本体场景
|
||||
PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景
|
||||
DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景
|
||||
GET /api/memory/ontology/scene/{scene_id} - 获取单个场景
|
||||
GET /api/memory/ontology/scenes - 获取场景列表
|
||||
POST /api/memory/ontology/class - 创建本体类型
|
||||
POST /api/memory/ontology/class - 创建本体类型(支持批量)
|
||||
PUT /api/memory/ontology/class/{class_id} - 更新本体类型
|
||||
DELETE /api/memory/ontology/class/{class_id} - 删除本体类型
|
||||
GET /api/memory/ontology/class/{class_id} - 获取单个类型
|
||||
@@ -19,11 +20,15 @@ Endpoints:
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Dict, Optional
|
||||
import io
|
||||
from typing import Dict, Optional, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -31,11 +36,10 @@ from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
from typing import List
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
from app.schemas.ontology_schemas import (
|
||||
ExportRequest,
|
||||
ExportResponse,
|
||||
ExportBySceneRequest,
|
||||
ExportBySceneResponse,
|
||||
ExtractionRequest,
|
||||
ExtractionResponse,
|
||||
SceneCreateRequest,
|
||||
@@ -46,6 +50,7 @@ from app.schemas.ontology_schemas import (
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ImportOwlResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
@@ -187,22 +192,19 @@ async def extract_ontology(
|
||||
从场景描述中提取符合OWL规范的本体类。
|
||||
提取结果仅返回给前端,不会自动保存到数据库。
|
||||
前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。
|
||||
支持中英文切换,通过 X-Language-Type Header 指定语言。
|
||||
|
||||
Args:
|
||||
request: 提取请求,包含scenario、domain、llm_id和scene_id
|
||||
language_type: 语言类型,'zh'(中文)或 'en'(英文),默认 'zh'
|
||||
language_type: 语言类型 Header (zh/en)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Ontology extraction requested by user {current_user.id}, "
|
||||
f"scenario_length={len(request.scenario)}, "
|
||||
f"domain={request.domain}, "
|
||||
f"llm_id={request.llm_id}, "
|
||||
f"scene_id={request.scene_id}, "
|
||||
f"language_type={language_type}"
|
||||
f"scene_id={request.scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -222,7 +224,7 @@ async def extract_ontology(
|
||||
llm_id=request.llm_id
|
||||
)
|
||||
|
||||
# 调用服务层执行提取,传入scene_id和workspace_id
|
||||
# 调用服务层执行提取
|
||||
result = await service.extract_ontology(
|
||||
scenario=request.scenario,
|
||||
domain=request.domain,
|
||||
@@ -231,7 +233,7 @@ async def extract_ontology(
|
||||
language=language
|
||||
)
|
||||
|
||||
# 构建响应(语言已在提取时通过模板控制,无需二次翻译)
|
||||
# 构建响应
|
||||
response = ExtractionResponse(
|
||||
classes=result.classes,
|
||||
domain=result.domain,
|
||||
@@ -240,7 +242,7 @@ async def extract_ontology(
|
||||
|
||||
api_logger.info(
|
||||
f"Ontology extraction completed, extracted {len(result.classes)} classes, "
|
||||
f"saved to scene {request.scene_id}, language={language_type}"
|
||||
f"scene_id={request.scene_id}, language={language}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(), msg="本体提取成功")
|
||||
@@ -261,146 +263,6 @@ async def extract_ontology(
|
||||
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
|
||||
|
||||
|
||||
@router.post("/export", response_model=ApiResponse)
|
||||
async def export_owl(
|
||||
request: ExportRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""导出OWL文件
|
||||
|
||||
将提取的本体类导出为OWL文件,支持多种格式。
|
||||
导出操作不需要LLM,只使用OWL验证器和Owlready2库。
|
||||
|
||||
Args:
|
||||
request: 导出请求,包含classes、format和include_metadata
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含OWL文件内容的响应
|
||||
|
||||
Supported formats:
|
||||
- rdfxml: 标准OWL RDF/XML格式(完整)
|
||||
- turtle: Turtle格式(可读性好)
|
||||
- ntriples: N-Triples格式(简单)
|
||||
- json: JSON格式(简化,只包含类信息)
|
||||
|
||||
Response format:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "OWL文件导出成功",
|
||||
"data": {
|
||||
"owl_content": "...",
|
||||
"format": "rdfxml",
|
||||
"classes_count": 7
|
||||
}
|
||||
}
|
||||
"""
|
||||
api_logger.info(
|
||||
f"OWL export requested by user {current_user.id}, "
|
||||
f"classes_count={len(request.classes)}, "
|
||||
f"format={request.format}, "
|
||||
f"include_metadata={request.include_metadata}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证格式
|
||||
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
||||
if request.format not in valid_formats:
|
||||
api_logger.warning(f"Invalid export format: {request.format}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"不支持的导出格式",
|
||||
f"format必须是以下之一: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON格式直接导出,不需要OWL验证
|
||||
if request.format == "json":
|
||||
owl_validator = OWLValidator()
|
||||
owl_content = owl_validator.export_to_owl(
|
||||
world=None,
|
||||
format="json",
|
||||
classes=request.classes
|
||||
)
|
||||
|
||||
response = ExportResponse(
|
||||
owl_content=owl_content,
|
||||
format=request.format,
|
||||
classes_count=len(request.classes)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"JSON export completed, content_length={len(owl_content)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(), msg="OWL文件导出成功")
|
||||
|
||||
# 创建临时文件路径
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode='w',
|
||||
suffix='.owl',
|
||||
delete=False
|
||||
) as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
|
||||
# 导出操作不需要LLM,直接使用OWL验证器
|
||||
owl_validator = OWLValidator()
|
||||
|
||||
# 验证本体类
|
||||
logger.debug("Validating ontology classes")
|
||||
is_valid, errors, world = owl_validator.validate_ontology_classes(
|
||||
classes=request.classes,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues during export: {errors}"
|
||||
)
|
||||
# 继续导出,但记录警告
|
||||
|
||||
if not world:
|
||||
error_msg = "Failed to create OWL world for export"
|
||||
logger.error(error_msg)
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
|
||||
|
||||
# 导出OWL文件
|
||||
logger.info(f"Exporting to {request.format} format")
|
||||
owl_content = owl_validator.export_to_owl(
|
||||
world=world,
|
||||
output_path=output_path,
|
||||
format=request.format,
|
||||
classes=request.classes
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ExportResponse(
|
||||
owl_content=owl_content,
|
||||
format=request.format,
|
||||
classes_count=len(request.classes)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"OWL export completed, format={request.format}, "
|
||||
f"content_length={len(owl_content)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(), msg="OWL文件导出成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 验证错误 (400)
|
||||
api_logger.warning(f"Validation error in export: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
# 运行时错误 (500)
|
||||
api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
# 未知错误 (500)
|
||||
api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体场景管理接口 ====================
|
||||
@@ -893,3 +755,370 @@ async def get_class(
|
||||
"""
|
||||
from app.controllers.ontology_secondary_routes import get_class_handler
|
||||
return await get_class_handler(class_id, db, current_user)
|
||||
|
||||
|
||||
# ==================== OWL 导入接口 ====================
|
||||
|
||||
@router.post("/import", response_model=ApiResponse)
|
||||
async def import_owl_file(
|
||||
scene_name: str = Form(..., description="场景名称"),
|
||||
scene_description: Optional[str] = Form(None, description="场景描述(可选)"),
|
||||
file: UploadFile = File(..., description="OWL/TTL文件"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""导入 OWL/TTL 文件并创建新场景
|
||||
|
||||
上传 OWL 或 TTL 文件,解析其中定义的本体类型(owl:Class),
|
||||
解析成功后创建新场景,并将类型保存到该场景的 ontology_class 表中。
|
||||
|
||||
文件格式根据文件扩展名自动识别:
|
||||
- .owl, .rdf, .xml -> rdfxml 格式
|
||||
- .ttl -> turtle 格式
|
||||
|
||||
Args:
|
||||
scene_name: 场景名称(表单字段)
|
||||
scene_description: 场景描述(表单字段,可选)
|
||||
file: 上传的文件(支持 .owl, .ttl, .rdf, .xml)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含导入结果
|
||||
"""
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
# 根据文件扩展名确定格式
|
||||
filename = file.filename.lower() if file.filename else ""
|
||||
if filename.endswith('.ttl'):
|
||||
owl_format = "turtle"
|
||||
file_type = "ttl"
|
||||
elif filename.endswith(('.owl', '.rdf', '.xml')):
|
||||
owl_format = "rdfxml"
|
||||
file_type = "owl"
|
||||
else:
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"文件格式不支持",
|
||||
f"不支持的文件格式: {filename},支持的格式: .owl, .ttl, .rdf, .xml"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"OWL import requested by user {current_user.id}, "
|
||||
f"scene_name={scene_name}, "
|
||||
f"filename={file.filename}, "
|
||||
f"format={owl_format}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 1. 验证场景名称不为空
|
||||
if not scene_name or not scene_name.strip():
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "场景名称不能为空")
|
||||
|
||||
scene_name = scene_name.strip()
|
||||
|
||||
# 2. 检查场景名称是否已存在
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
existing_scene = scene_repo.get_by_name(scene_name, workspace_id)
|
||||
if existing_scene:
|
||||
api_logger.warning(f"Scene name already exists: {scene_name}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"场景名称已存在",
|
||||
f"工作空间下已存在名为 '{scene_name}' 的场景"
|
||||
)
|
||||
|
||||
# 3. 读取文件内容
|
||||
try:
|
||||
content = await file.read()
|
||||
owl_content = content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
f"{file_type}文件导入失败",
|
||||
"文件编码错误,请确保文件使用 UTF-8 编码"
|
||||
)
|
||||
|
||||
# 4. 解析 OWL 内容(先解析,成功后再创建场景)
|
||||
owl_validator = OWLValidator()
|
||||
parsed_classes = owl_validator.parse_owl_content(
|
||||
owl_content=owl_content,
|
||||
format=owl_format
|
||||
)
|
||||
|
||||
if not parsed_classes:
|
||||
api_logger.warning("No classes found in OWL content")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"未找到本体类型",
|
||||
"文件中没有定义任何本体类型(owl:Class)"
|
||||
)
|
||||
|
||||
# 5. 文件解析成功,创建场景
|
||||
scene = scene_repo.create(
|
||||
scene_data={
|
||||
"scene_name": scene_name,
|
||||
"scene_description": scene_description
|
||||
},
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
scene_uuid = scene.scene_id
|
||||
|
||||
api_logger.info(f"Scene created for import: {scene_uuid}")
|
||||
|
||||
# 6. 批量创建类型(去重同一批次内的重复类型)
|
||||
class_repo = OntologyClassRepository(db)
|
||||
created_items = []
|
||||
existing_names = set()
|
||||
skipped_count = 0
|
||||
|
||||
for cls in parsed_classes:
|
||||
class_name = cls["name"]
|
||||
class_description = cls.get("description")
|
||||
|
||||
# 检查同一批次内是否重复
|
||||
if class_name in existing_names:
|
||||
skipped_count += 1
|
||||
api_logger.debug(f"Skipping duplicate class in batch: {class_name}")
|
||||
continue
|
||||
|
||||
# 创建类型
|
||||
ontology_class = class_repo.create(
|
||||
class_data={
|
||||
"class_name": class_name,
|
||||
"class_description": class_description
|
||||
},
|
||||
scene_id=scene_uuid
|
||||
)
|
||||
|
||||
# 添加到已存在集合,防止同一批次内重复
|
||||
existing_names.add(class_name)
|
||||
|
||||
created_items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
# 7. 提交事务
|
||||
db.commit()
|
||||
|
||||
# 8. 构建响应
|
||||
response = ImportOwlResponse(
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
imported_count=len(created_items),
|
||||
skipped_count=skipped_count,
|
||||
items=created_items
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"{file_type} import completed, "
|
||||
f"scene_id={scene_uuid}, "
|
||||
f"scene_name={scene_name}, "
|
||||
f"format={owl_format}, "
|
||||
f"imported={len(created_items)}, "
|
||||
f"skipped={skipped_count}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(), msg=f"{file_type}文件导入成功")
|
||||
|
||||
except ValueError as e:
|
||||
db.rollback()
|
||||
api_logger.warning(f"Validation error in import: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, f"{file_type}文件导入失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"Unexpected error in import: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导入失败", str(e))
|
||||
|
||||
# ==================== OWL 导出接口 ====================
|
||||
@router.post("/export")
|
||||
async def export_owl_by_scene(
|
||||
request: ExportBySceneRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""按场景导出OWL/TTL文件
|
||||
|
||||
根据scene_id从数据库查询该场景下的所有本体类型,并导出为文件下载。
|
||||
|
||||
Args:
|
||||
request: 导出请求,包含 scene_id 和 format
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
StreamingResponse: 文件流响应,浏览器会直接下载文件
|
||||
"""
|
||||
from uuid import UUID
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
api_logger.info(
|
||||
f"OWL export by scene requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, "
|
||||
f"format={request.format}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证格式参数
|
||||
valid_formats = ["rdfxml", "turtle"]
|
||||
owl_format = request.format.lower() if request.format else "rdfxml"
|
||||
if owl_format not in valid_formats:
|
||||
api_logger.warning(f"Invalid format: {request.format}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"格式参数无效",
|
||||
f"不支持的格式: {request.format},支持的格式: rdfxml, turtle"
|
||||
)
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 1. 查询场景信息
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(request.scene_id)
|
||||
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {request.scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"找不到场景: {request.scene_id}")
|
||||
|
||||
# 验证场景属于当前工作空间
|
||||
if scene.workspace_id != workspace_id:
|
||||
api_logger.warning(
|
||||
f"Scene {request.scene_id} does not belong to workspace {workspace_id}"
|
||||
)
|
||||
return fail(BizCode.FORBIDDEN, "无权访问", "该场景不属于当前工作空间")
|
||||
|
||||
# 2. 查询场景下的所有本体类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_classes_db = class_repo.get_by_scene(request.scene_id)
|
||||
|
||||
if not ontology_classes_db:
|
||||
api_logger.warning(f"No classes found in scene: {request.scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "场景为空", "该场景下没有定义任何本体类型")
|
||||
|
||||
# 3. 将数据库模型转换为OWL导出所需的OntologyClass格式
|
||||
ontology_classes: List[OntologyClass] = []
|
||||
for db_class in ontology_classes_db:
|
||||
owl_class = OntologyClass(
|
||||
id=str(db_class.class_id),
|
||||
name=db_class.class_name,
|
||||
name_chinese=db_class.class_name if _is_chinese(db_class.class_name) else None,
|
||||
description=db_class.class_description or "",
|
||||
examples=[],
|
||||
parent_class=None,
|
||||
entity_type="Concept",
|
||||
domain=scene.scene_name
|
||||
)
|
||||
ontology_classes.append(owl_class)
|
||||
|
||||
# 4. 确定文件名、扩展名和 MIME 类型
|
||||
file_ext = ".ttl" if owl_format == "turtle" else ".owl"
|
||||
filename = _sanitize_filename(scene.scene_name) + file_ext
|
||||
media_type = "text/turtle" if owl_format == "turtle" else "application/rdf+xml"
|
||||
file_type = "ttl" if owl_format == "turtle" else "owl"
|
||||
|
||||
# 5. 导出OWL文件
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode='w',
|
||||
suffix='.owl',
|
||||
delete=False
|
||||
) as tmp_file:
|
||||
output_path = tmp_file.name
|
||||
|
||||
owl_validator = OWLValidator()
|
||||
|
||||
# 验证本体类
|
||||
is_valid, errors, world = owl_validator.validate_ontology_classes(
|
||||
classes=ontology_classes,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues during export: {errors}"
|
||||
)
|
||||
|
||||
if not world:
|
||||
error_msg = "Failed to create OWL world for export"
|
||||
logger.error(error_msg)
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
|
||||
|
||||
# 导出OWL文件(使用请求指定的格式)
|
||||
owl_content = owl_validator.export_to_owl(
|
||||
world=world,
|
||||
output_path=output_path,
|
||||
format=owl_format,
|
||||
classes=ontology_classes
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"{file_type} export by scene completed, "
|
||||
f"scene={scene.scene_name}, "
|
||||
f"filename={filename}, "
|
||||
f"format={owl_format}, "
|
||||
f"classes_count={len(ontology_classes)}"
|
||||
)
|
||||
|
||||
# 6. 返回文件流响应
|
||||
# filename 使用 ASCII 安全的默认名,filename* 使用 UTF-8 编码的实际名称
|
||||
ascii_filename = f"ontology{file_ext}"
|
||||
encoded_filename = quote(filename)
|
||||
return StreamingResponse(
|
||||
io.BytesIO(owl_content.encode('utf-8')),
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=\"{ascii_filename}\"; filename*=UTF-8''{encoded_filename}"
|
||||
}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in export by scene: {str(e)}")
|
||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in export by scene: {str(e)}", exc_info=True)
|
||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in export by scene: {str(e)}", exc_info=True)
|
||||
file_type = "ttl" if (request.format and request.format.lower() == "turtle") else "owl"
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{file_type}文件导出失败", str(e))
|
||||
|
||||
|
||||
def _is_chinese(text: str) -> bool:
|
||||
"""检查文本是否包含中文字符"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
"""清理文件名,移除不合法字符"""
|
||||
import re
|
||||
# 移除或替换不合法的文件名字符
|
||||
sanitized = re.sub(r'[<>:"/\\|?*]', '_', name)
|
||||
# 移除前后空格
|
||||
sanitized = sanitized.strip()
|
||||
# 如果为空,使用默认名称
|
||||
if not sanitized:
|
||||
sanitized = "ontology_export"
|
||||
return sanitized
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
@@ -16,7 +16,6 @@ router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
@cur_workspace_access_guard()
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -29,7 +28,6 @@ def create_skill(
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
@cur_workspace_access_guard()
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
@@ -51,7 +49,6 @@ def list_skills(
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
@@ -64,7 +61,6 @@ def get_skill(
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
@cur_workspace_access_guard()
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
@@ -78,7 +74,6 @@ def update_skill(
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
@cur_workspace_access_guard()
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
|
||||
@@ -221,6 +221,28 @@ class Settings:
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
@@ -94,6 +94,31 @@ async def write(
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
# Fetch ontology types if scene_id is configured
|
||||
ontology_types = None
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
|
||||
@@ -58,12 +58,25 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology models
|
||||
from app.core.memory.models.ontology_models import (
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology extraction models (for extraction flow)
|
||||
from app.core.memory.models.ontology_extraction_models import (
|
||||
OntologyTypeInfo,
|
||||
OntologyTypeList,
|
||||
)
|
||||
|
||||
# Ontology general models (loaded from external ontology files)
|
||||
from app.core.memory.models.ontology_general_models import (
|
||||
OntologyFileFormat,
|
||||
GeneralOntologyType,
|
||||
GeneralOntologyTypeRegistry,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -114,6 +127,13 @@ __all__ = [
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Ontology type models for extraction flow
|
||||
"OntologyTypeInfo",
|
||||
"OntologyTypeList",
|
||||
# General ontology type models
|
||||
"OntologyFileFormat",
|
||||
"GeneralOntologyType",
|
||||
"GeneralOntologyTypeRegistry",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型数据结构模块
|
||||
|
||||
本模块定义用于在萃取流程中传递本体类型信息的轻量级数据类。
|
||||
|
||||
Classes:
|
||||
OntologyTypeInfo: 单个本体类型信息
|
||||
OntologyTypeList: 本体类型列表
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeInfo:
|
||||
"""本体类型信息,用于萃取流程中传递。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称
|
||||
class_description: 类型描述
|
||||
"""
|
||||
class_name: str
|
||||
class_description: str
|
||||
|
||||
def to_prompt_format(self) -> str:
|
||||
"""转换为提示词格式。
|
||||
|
||||
Returns:
|
||||
格式化的字符串,如 "- TypeName: Description"
|
||||
"""
|
||||
return f"- {self.class_name}: {self.class_description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeList:
|
||||
"""本体类型列表。
|
||||
|
||||
Attributes:
|
||||
types: 本体类型信息列表
|
||||
"""
|
||||
types: List[OntologyTypeInfo]
|
||||
|
||||
@classmethod
|
||||
def from_db_models(cls, ontology_classes: list) -> "OntologyTypeList":
|
||||
"""从数据库模型转换创建 OntologyTypeList。
|
||||
|
||||
Args:
|
||||
ontology_classes: OntologyClass 数据库模型列表,
|
||||
每个对象应包含 class_name 和 class_description 属性
|
||||
|
||||
Returns:
|
||||
包含转换后类型信息的 OntologyTypeList 实例
|
||||
"""
|
||||
types = [
|
||||
OntologyTypeInfo(
|
||||
class_name=oc.class_name,
|
||||
class_description=oc.class_description or ""
|
||||
)
|
||||
for oc in ontology_classes
|
||||
]
|
||||
return cls(types=types)
|
||||
|
||||
def to_prompt_section(self) -> str:
|
||||
"""转换为提示词中的类型列表部分。
|
||||
|
||||
Returns:
|
||||
格式化的类型列表字符串,每行一个类型;
|
||||
如果列表为空则返回空字符串
|
||||
"""
|
||||
if not self.types:
|
||||
return ""
|
||||
lines = [t.to_prompt_format() for t in self.types]
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_type_names(self) -> List[str]:
|
||||
"""获取所有类型名称列表。
|
||||
|
||||
Returns:
|
||||
类型名称字符串列表
|
||||
"""
|
||||
return [t.class_name for t in self.types]
|
||||
|
||||
def get_type_hierarchy_hints(self) -> List[str]:
|
||||
"""获取类型层次结构提示列表。
|
||||
|
||||
尝试从通用本体注册表中获取每个类型的继承链信息。
|
||||
|
||||
Returns:
|
||||
层次提示字符串列表,格式为 "类型名 → 父类1 → 父类2"
|
||||
"""
|
||||
hints = []
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger
|
||||
|
||||
merger = OntologyTypeMerger()
|
||||
for type_info in self.types:
|
||||
hint = merger.get_type_hierarchy_hint(type_info.class_name)
|
||||
if hint:
|
||||
hints.append(hint)
|
||||
except Exception:
|
||||
# 如果无法获取层次信息,返回空列表
|
||||
pass
|
||||
|
||||
return hints
|
||||
223
api/app/core/memory/models/ontology_general_models.py
Normal file
223
api/app/core/memory/models/ontology_general_models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""通用本体类型数据模型模块
|
||||
|
||||
本模块定义用于通用本体类型管理的数据结构,包括:
|
||||
- OntologyFileFormat: 本体文件格式枚举
|
||||
- GeneralOntologyType: 通用本体类型数据类
|
||||
- GeneralOntologyTypeRegistry: 通用本体类型注册表
|
||||
|
||||
Classes:
|
||||
OntologyFileFormat: 本体文件格式枚举,支持 TTL、OWL/XML、RDF/XML、N-Triples、JSON-LD
|
||||
GeneralOntologyType: 通用本体类型,包含类名、URI、标签、描述、父类等信息
|
||||
GeneralOntologyTypeRegistry: 类型注册表,管理类型集合和层次结构
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyFileFormat(Enum):
|
||||
"""本体文件格式枚举
|
||||
|
||||
支持的格式:
|
||||
- TURTLE: Turtle 格式 (.ttl 文件)
|
||||
- RDF_XML: RDF/XML 格式 (.owl, .rdf 文件)
|
||||
- N_TRIPLES: N-Triples 格式 (.nt 文件)
|
||||
- JSON_LD: JSON-LD 格式 (.jsonld, .json 文件)
|
||||
"""
|
||||
TURTLE = "turtle" # .ttl 文件
|
||||
RDF_XML = "xml" # .owl, .rdf (RDF/XML 格式)
|
||||
N_TRIPLES = "nt" # .nt 文件
|
||||
JSON_LD = "json-ld" # .jsonld 文件
|
||||
|
||||
@classmethod
|
||||
def from_extension(cls, file_path: str) -> "OntologyFileFormat":
|
||||
"""根据文件扩展名推断格式
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
推断出的文件格式,默认返回 RDF_XML
|
||||
"""
|
||||
ext = file_path.lower().split('.')[-1]
|
||||
format_map = {
|
||||
'ttl': cls.TURTLE,
|
||||
'owl': cls.RDF_XML,
|
||||
'rdf': cls.RDF_XML,
|
||||
'nt': cls.N_TRIPLES,
|
||||
'jsonld': cls.JSON_LD,
|
||||
'json': cls.JSON_LD,
|
||||
}
|
||||
return format_map.get(ext, cls.RDF_XML)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyType:
|
||||
"""通用本体类型
|
||||
|
||||
表示从本体文件中解析出的类型定义,包含类型的基本信息和层次关系。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称,如 "Person"
|
||||
class_uri: 完整 URI,如 "http://dbpedia.org/ontology/Person"
|
||||
labels: 多语言标签字典,键为语言代码(如 "en", "zh"),值为标签文本
|
||||
description: 类型描述
|
||||
parent_class: 父类名称,用于构建类型层次
|
||||
source_file: 来源文件路径
|
||||
"""
|
||||
class_name: str # 类型名称,如 "Person"
|
||||
class_uri: str # 完整 URI
|
||||
labels: Dict[str, str] = field(default_factory=dict) # 多语言标签
|
||||
description: Optional[str] = None # 类型描述
|
||||
parent_class: Optional[str] = None # 父类名称
|
||||
source_file: Optional[str] = None # 来源文件
|
||||
|
||||
def get_label(self, lang: str = "en") -> str:
|
||||
"""获取指定语言的标签
|
||||
|
||||
优先返回指定语言的标签,如果不存在则尝试返回英文标签,
|
||||
最后返回类型名称作为默认值。
|
||||
|
||||
Args:
|
||||
lang: 语言代码,默认为 "en"
|
||||
|
||||
Returns:
|
||||
指定语言的标签,或默认值
|
||||
"""
|
||||
return self.labels.get(lang, self.labels.get("en", self.class_name))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyTypeRegistry:
|
||||
"""通用本体类型注册表
|
||||
|
||||
管理解析后的本体类型集合,提供类型查询、层次遍历、注册表合并等功能。
|
||||
|
||||
Attributes:
|
||||
types: 类型字典,键为类型名称,值为 GeneralOntologyType 实例
|
||||
hierarchy: 层次结构字典,键为父类名称,值为子类名称集合
|
||||
source_files: 已加载的源文件路径列表
|
||||
"""
|
||||
types: Dict[str, GeneralOntologyType] = field(default_factory=dict)
|
||||
hierarchy: Dict[str, Set[str]] = field(default_factory=dict) # 父类 -> 子类集合
|
||||
source_files: List[str] = field(default_factory=list)
|
||||
|
||||
def get_type(self, name: str) -> Optional[GeneralOntologyType]:
|
||||
"""根据名称获取类型
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
对应的 GeneralOntologyType 实例,如果不存在则返回 None
|
||||
"""
|
||||
return self.types.get(name)
|
||||
|
||||
def get_ancestors(self, name: str) -> List[str]:
|
||||
"""获取类型的所有祖先类型(防循环)
|
||||
|
||||
从当前类型开始,沿着父类链向上遍历,返回所有祖先类型名称。
|
||||
使用 visited 集合防止循环引用导致的无限循环。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
祖先类型名称列表,按从近到远的顺序排列
|
||||
"""
|
||||
ancestors = []
|
||||
current = name
|
||||
visited = set()
|
||||
while current and current not in visited:
|
||||
visited.add(current)
|
||||
type_info = self.types.get(current)
|
||||
if type_info and type_info.parent_class:
|
||||
# 检测循环引用
|
||||
if type_info.parent_class in visited:
|
||||
logger.warning(
|
||||
f"检测到类型层次循环引用: {current} -> {type_info.parent_class},"
|
||||
f"已遍历路径: {' -> '.join([name] + ancestors)}"
|
||||
)
|
||||
break
|
||||
ancestors.append(type_info.parent_class)
|
||||
current = type_info.parent_class
|
||||
else:
|
||||
break
|
||||
return ancestors
|
||||
|
||||
def get_descendants(self, name: str) -> Set[str]:
|
||||
"""获取类型的所有后代类型
|
||||
|
||||
从当前类型开始,沿着子类关系向下遍历,返回所有后代类型名称。
|
||||
使用广度优先搜索,避免重复处理已访问的类型。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
后代类型名称集合
|
||||
"""
|
||||
descendants: Set[str] = set()
|
||||
to_process = [name]
|
||||
while to_process:
|
||||
current = to_process.pop()
|
||||
children = self.hierarchy.get(current, set())
|
||||
new_children = children - descendants
|
||||
descendants.update(new_children)
|
||||
to_process.extend(new_children)
|
||||
return descendants
|
||||
|
||||
def merge(self, other: "GeneralOntologyTypeRegistry") -> None:
|
||||
"""合并另一个注册表(先加载的优先)
|
||||
|
||||
将另一个注册表的类型和层次结构合并到当前注册表。
|
||||
对于同名类型,保留当前注册表中已存在的定义(先加载优先)。
|
||||
层次结构会合并所有子类关系。
|
||||
|
||||
Args:
|
||||
other: 要合并的另一个注册表
|
||||
"""
|
||||
for name, type_info in other.types.items():
|
||||
if name not in self.types:
|
||||
self.types[name] = type_info
|
||||
for parent, children in other.hierarchy.items():
|
||||
if parent not in self.hierarchy:
|
||||
self.hierarchy[parent] = set()
|
||||
self.hierarchy[parent].update(children)
|
||||
self.source_files.extend(other.source_files)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取注册表统计信息
|
||||
|
||||
Returns:
|
||||
包含以下键的字典:
|
||||
- total_types: 总类型数
|
||||
- root_types: 根类型数(无父类的类型)
|
||||
- max_depth: 类型层次的最大深度
|
||||
- source_files: 源文件列表
|
||||
"""
|
||||
return {
|
||||
"total_types": len(self.types),
|
||||
"root_types": len([t for t in self.types.values() if not t.parent_class]),
|
||||
"max_depth": self._calculate_max_depth(),
|
||||
"source_files": self.source_files,
|
||||
}
|
||||
|
||||
def _calculate_max_depth(self) -> int:
|
||||
"""计算类型层次的最大深度
|
||||
|
||||
遍历所有类型,计算每个类型到根的深度,返回最大值。
|
||||
|
||||
Returns:
|
||||
类型层次的最大深度
|
||||
"""
|
||||
max_depth = 0
|
||||
for type_name in self.types:
|
||||
depth = len(self.get_ancestors(type_name))
|
||||
max_depth = max(max_depth, depth)
|
||||
return max_depth
|
||||
30
api/app/core/memory/ontology_services/__init__.py
Normal file
30
api/app/core/memory/ontology_services/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型服务模块
|
||||
|
||||
本模块提供本体类型相关的服务,包括:
|
||||
- OntologyTypeMerger: 本体类型合并服务
|
||||
- get_general_ontology_registry: 获取通用本体类型注册表(单例,懒加载)
|
||||
- get_ontology_type_merger: 获取类型合并服务实例
|
||||
- reload_ontology_registry: 重新加载本体注册表(实验模式)
|
||||
- clear_ontology_cache: 清除本体缓存
|
||||
- is_general_ontology_enabled: 检查通用本体类型功能是否启用
|
||||
"""
|
||||
|
||||
from .ontology_type_merger import OntologyTypeMerger, DEFAULT_CORE_GENERAL_TYPES
|
||||
from .ontology_type_loader import (
|
||||
get_general_ontology_registry,
|
||||
get_ontology_type_merger,
|
||||
reload_ontology_registry,
|
||||
clear_ontology_cache,
|
||||
is_general_ontology_enabled,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OntologyTypeMerger",
|
||||
"DEFAULT_CORE_GENERAL_TYPES",
|
||||
"get_general_ontology_registry",
|
||||
"get_ontology_type_merger",
|
||||
"reload_ontology_registry",
|
||||
"clear_ontology_cache",
|
||||
"is_general_ontology_enabled",
|
||||
]
|
||||
145
api/app/core/memory/ontology_services/ontology_type_loader.py
Normal file
145
api/app/core/memory/ontology_services/ontology_type_loader.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""本体类型加载器
|
||||
|
||||
提供统一的本体类型加载逻辑,避免代码重复。
|
||||
|
||||
Functions:
|
||||
load_ontology_types_for_scene: 从数据库加载场景的本体类型
|
||||
is_general_ontology_enabled: 检查是否启用通用本体
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_ontology_types_for_scene(
|
||||
scene_id: Optional[UUID],
|
||||
workspace_id: UUID,
|
||||
db: Session
|
||||
) -> Optional["OntologyTypeList"]:
|
||||
"""从数据库加载场景的本体类型
|
||||
|
||||
统一的本体类型加载逻辑,用于替代各处重复的加载代码。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID,如果为 None 则返回 None
|
||||
workspace_id: 工作空间ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
OntologyTypeList 如果场景有类型定义,否则返回 None
|
||||
|
||||
Examples:
|
||||
>>> ontology_types = load_ontology_types_for_scene(
|
||||
... scene_id=scene_uuid,
|
||||
... workspace_id=workspace_uuid,
|
||||
... db=db_session
|
||||
... )
|
||||
>>> if ontology_types:
|
||||
... print(f"Loaded {len(ontology_types.types)} types")
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
# 查询场景的本体类型
|
||||
ontology_repo = OntologyClassRepository(db)
|
||||
ontology_classes = ontology_repo.get_classes_by_scene(
|
||||
scene_id=scene_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology types found for scene_id: {scene_id}")
|
||||
return None
|
||||
|
||||
# 转换为 OntologyTypeList
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {scene_id}"
|
||||
)
|
||||
|
||||
return ontology_types
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ontology types for scene_id {scene_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def create_empty_ontology_type_list() -> Optional["OntologyTypeList"]:
|
||||
"""创建空的本体类型列表(用于仅使用通用类型的场景)
|
||||
|
||||
Returns:
|
||||
空的 OntologyTypeList 如果通用本体已启用,否则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
|
||||
if is_general_ontology_enabled():
|
||||
logger.info("Creating empty OntologyTypeList for general types only")
|
||||
return OntologyTypeList(types=[])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create empty OntologyTypeList: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def is_general_ontology_enabled() -> bool:
|
||||
"""检查是否启用了通用本体
|
||||
|
||||
Returns:
|
||||
True 如果通用本体已启用,否则 False
|
||||
"""
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger
|
||||
|
||||
merger = OntologyTypeMerger()
|
||||
return merger.general_registry is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check general ontology status: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_ontology_types_with_fallback(
|
||||
scene_id: Optional[UUID],
|
||||
workspace_id: UUID,
|
||||
db: Session,
|
||||
enable_general_fallback: bool = True
|
||||
) -> Optional["OntologyTypeList"]:
|
||||
"""加载本体类型,如果场景没有类型则回退到通用类型
|
||||
|
||||
这是一个便捷函数,组合了场景类型加载和通用类型回退逻辑。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
workspace_id: 工作空间ID
|
||||
db: 数据库会话
|
||||
enable_general_fallback: 是否在没有场景类型时启用通用类型回退
|
||||
|
||||
Returns:
|
||||
OntologyTypeList 或 None
|
||||
"""
|
||||
# 首先尝试加载场景类型
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=scene_id,
|
||||
workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 如果没有场景类型且启用了回退,创建空列表以使用通用类型
|
||||
if ontology_types is None and enable_general_fallback:
|
||||
ontology_types = create_empty_ontology_type_list()
|
||||
if ontology_types:
|
||||
logger.info("No scene ontology types, will use general ontology types only")
|
||||
|
||||
return ontology_types
|
||||
231
api/app/core/memory/ontology_services/ontology_type_merger.py
Normal file
231
api/app/core/memory/ontology_services/ontology_type_merger.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型合并服务模块
|
||||
|
||||
本模块实现本体类型合并服务,负责按优先级合并场景类型与通用类型。
|
||||
|
||||
合并优先级:
|
||||
1. 场景特定类型(最高优先级)
|
||||
2. 核心通用类型
|
||||
3. 相关父类类型(最低优先级)
|
||||
|
||||
Classes:
|
||||
OntologyTypeMerger: 本体类型合并服务类
|
||||
|
||||
Constants:
|
||||
DEFAULT_CORE_GENERAL_TYPES: 默认核心通用类型集合
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from app.core.memory.models.ontology_general_models import GeneralOntologyTypeRegistry
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeInfo, OntologyTypeList
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认核心通用类型
|
||||
DEFAULT_CORE_GENERAL_TYPES: Set[str] = {
|
||||
"Person", "Organization", "Company", "GovernmentAgency",
|
||||
"Place", "Location", "City", "Country", "Building",
|
||||
"Event", "SportsEvent", "MusicEvent", "SocialEvent",
|
||||
"Work", "Book", "Film", "Software", "Album",
|
||||
"Concept", "TopicalConcept", "AcademicSubject",
|
||||
"Device", "Food", "Drug", "ChemicalSubstance",
|
||||
"TimePeriod", "Year",
|
||||
}
|
||||
|
||||
|
||||
class OntologyTypeMerger:
|
||||
"""本体类型合并服务
|
||||
|
||||
负责按优先级合并场景类型与通用类型,生成用于三元组提取的类型列表。
|
||||
|
||||
合并优先级:
|
||||
1. 场景特定类型(最高优先级)- 标记为 [场景类型]
|
||||
2. 核心通用类型 - 标记为 [通用类型]
|
||||
3. 相关父类类型(最低优先级)- 标记为 [通用父类]
|
||||
|
||||
Attributes:
|
||||
general_registry: 通用本体类型注册表
|
||||
max_types_in_prompt: Prompt 中最大类型数量限制
|
||||
core_types: 核心通用类型集合
|
||||
|
||||
Example:
|
||||
>>> registry = GeneralOntologyTypeRegistry()
|
||||
>>> merger = OntologyTypeMerger(registry, max_types_in_prompt=50)
|
||||
>>> merged = merger.merge(scene_types)
|
||||
>>> print(len(merged.types))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
general_registry: GeneralOntologyTypeRegistry,
|
||||
max_types_in_prompt: int = 50,
|
||||
core_types: Optional[List[str]] = None
|
||||
):
|
||||
"""初始化本体类型合并服务
|
||||
|
||||
Args:
|
||||
general_registry: 通用本体类型注册表
|
||||
max_types_in_prompt: Prompt 中最大类型数量,默认 50
|
||||
core_types: 自定义核心类型列表,如果为 None 则使用默认核心类型
|
||||
"""
|
||||
self.general_registry = general_registry
|
||||
self.max_types_in_prompt = max_types_in_prompt
|
||||
self.core_types: Set[str] = set(core_types) if core_types else DEFAULT_CORE_GENERAL_TYPES.copy()
|
||||
|
||||
def update_core_types(self, core_types: List[str]) -> None:
|
||||
"""动态更新核心类型列表
|
||||
|
||||
更新后立即生效,无需重启服务。
|
||||
|
||||
Args:
|
||||
core_types: 新的核心类型列表
|
||||
"""
|
||||
self.core_types = set(core_types)
|
||||
logger.info(f"核心类型已更新: {len(self.core_types)} 个类型")
|
||||
|
||||
def merge(
|
||||
self,
|
||||
scene_types: Optional[OntologyTypeList],
|
||||
include_related_types: bool = True
|
||||
) -> OntologyTypeList:
|
||||
"""合并场景类型与通用类型
|
||||
|
||||
按优先级合并类型:
|
||||
1. 场景特定类型(最高优先级)
|
||||
2. 核心通用类型
|
||||
3. 相关父类类型(可选)
|
||||
|
||||
合并后的类型总数不超过 max_types_in_prompt。
|
||||
|
||||
Args:
|
||||
scene_types: 场景特定类型列表,可以为 None
|
||||
include_related_types: 是否包含相关父类类型,默认 True
|
||||
|
||||
Returns:
|
||||
合并后的类型列表,每个类型带有来源标记
|
||||
"""
|
||||
merged_types: List[OntologyTypeInfo] = []
|
||||
seen_names: Set[str] = set()
|
||||
|
||||
# 1. 场景特定类型(最高优先级)
|
||||
scene_type_count = 0
|
||||
if scene_types and scene_types.types:
|
||||
for scene_type in scene_types.types:
|
||||
if scene_type.class_name not in seen_names:
|
||||
merged_types.append(OntologyTypeInfo(
|
||||
class_name=scene_type.class_name,
|
||||
class_description=f"[场景类型] {scene_type.class_description}"
|
||||
))
|
||||
seen_names.add(scene_type.class_name)
|
||||
scene_type_count += 1
|
||||
|
||||
# 2. 核心通用类型
|
||||
remaining_slots = self.max_types_in_prompt - len(merged_types)
|
||||
core_types_added: List[OntologyTypeInfo] = []
|
||||
|
||||
for type_name in self.core_types:
|
||||
if type_name not in seen_names and remaining_slots > 0:
|
||||
general_type = self.general_registry.get_type(type_name)
|
||||
if general_type:
|
||||
description = (
|
||||
general_type.labels.get("zh") or
|
||||
general_type.description or
|
||||
general_type.get_label("en") or
|
||||
type_name
|
||||
)
|
||||
core_types_added.append(OntologyTypeInfo(
|
||||
class_name=type_name,
|
||||
class_description=f"[通用类型] {description}"
|
||||
))
|
||||
seen_names.add(type_name)
|
||||
remaining_slots -= 1
|
||||
|
||||
merged_types.extend(core_types_added)
|
||||
|
||||
# 3. 相关父类类型
|
||||
related_types_added: List[OntologyTypeInfo] = []
|
||||
if include_related_types and scene_types and scene_types.types:
|
||||
for scene_type in scene_types.types:
|
||||
if remaining_slots <= 0:
|
||||
break
|
||||
general_type = self.general_registry.get_type(scene_type.class_name)
|
||||
if general_type and general_type.parent_class:
|
||||
parent_name = general_type.parent_class
|
||||
if parent_name not in seen_names:
|
||||
parent_type = self.general_registry.get_type(parent_name)
|
||||
if parent_type:
|
||||
description = (
|
||||
parent_type.labels.get("zh") or
|
||||
parent_type.description or
|
||||
parent_name
|
||||
)
|
||||
related_types_added.append(OntologyTypeInfo(
|
||||
class_name=parent_name,
|
||||
class_description=f"[通用父类] {description}"
|
||||
))
|
||||
seen_names.add(parent_name)
|
||||
remaining_slots -= 1
|
||||
|
||||
merged_types.extend(related_types_added)
|
||||
|
||||
logger.info(
|
||||
f"类型合并完成: 场景类型 {scene_type_count} 个, "
|
||||
f"核心通用类型 {len(core_types_added)} 个, "
|
||||
f"相关类型 {len(related_types_added)} 个, "
|
||||
f"总计 {len(merged_types)} 个"
|
||||
)
|
||||
|
||||
return OntologyTypeList(types=merged_types)
|
||||
|
||||
def get_type_hierarchy_hint(self, type_name: str) -> Optional[str]:
|
||||
"""获取类型的层次提示信息(最多 3 级)
|
||||
|
||||
返回类型的继承链信息,格式为 "类型名 → 父类1 → 父类2 → 父类3"。
|
||||
|
||||
Args:
|
||||
type_name: 类型名称
|
||||
|
||||
Returns:
|
||||
层次提示字符串,如果类型不存在或没有父类则返回 None
|
||||
"""
|
||||
general_type = self.general_registry.get_type(type_name)
|
||||
if not general_type:
|
||||
return None
|
||||
ancestors = self.general_registry.get_ancestors(type_name)
|
||||
if ancestors:
|
||||
# 限制最多 3 级祖先
|
||||
return f"{type_name} → {' → '.join(ancestors[:3])}"
|
||||
return None
|
||||
|
||||
def get_merge_statistics(self, scene_types: Optional[OntologyTypeList]) -> dict:
|
||||
"""获取合并统计信息
|
||||
|
||||
执行合并操作并返回各类型来源的数量统计。
|
||||
|
||||
Args:
|
||||
scene_types: 场景特定类型列表
|
||||
|
||||
Returns:
|
||||
包含以下键的统计字典:
|
||||
- total_types: 合并后总类型数
|
||||
- scene_types: 场景类型数量
|
||||
- general_types: 通用类型数量
|
||||
- parent_types: 父类类型数量
|
||||
- available_core_types: 可用核心类型数量
|
||||
- registry_total_types: 注册表中总类型数
|
||||
"""
|
||||
merged = self.merge(scene_types)
|
||||
scene_count = sum(1 for t in merged.types if "[场景类型]" in t.class_description)
|
||||
general_count = sum(1 for t in merged.types if "[通用类型]" in t.class_description)
|
||||
parent_count = sum(1 for t in merged.types if "[通用父类]" in t.class_description)
|
||||
|
||||
return {
|
||||
"total_types": len(merged.types),
|
||||
"scene_types": scene_count,
|
||||
"general_types": general_count,
|
||||
"parent_types": parent_count,
|
||||
"available_core_types": len(self.core_types),
|
||||
"registry_total_types": len(self.general_registry.types),
|
||||
}
|
||||
@@ -34,6 +34,8 @@ from app.core.memory.models.graph_models import (
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
)
|
||||
@@ -95,6 +97,8 @@ class ExtractionOrchestrator:
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
enable_general_types: bool = True,
|
||||
language: str = "zh",
|
||||
):
|
||||
"""
|
||||
@@ -119,6 +123,29 @@ class ExtractionOrchestrator:
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
self.embedding_id = embedding_id # 保存嵌入模型ID
|
||||
self.language = language # 保存语言配置
|
||||
|
||||
# 处理本体类型配置
|
||||
# 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并
|
||||
# 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合
|
||||
if enable_general_types and ontology_types:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import (
|
||||
get_ontology_type_merger,
|
||||
is_general_ontology_enabled,
|
||||
)
|
||||
if is_general_ontology_enabled():
|
||||
merger = get_ontology_type_merger()
|
||||
self.ontology_types = merger.merge(ontology_types)
|
||||
logger.info(
|
||||
f"已启用通用本体类型融合: 场景类型 {len(ontology_types.types) if ontology_types.types else 0} 个 -> "
|
||||
f"合并后 {len(self.ontology_types.types) if self.ontology_types.types else 0} 个"
|
||||
)
|
||||
else:
|
||||
self.ontology_types = ontology_types
|
||||
logger.info("通用本体类型功能已在配置中禁用,仅使用场景类型")
|
||||
else:
|
||||
self.ontology_types = ontology_types
|
||||
if not enable_general_types and ontology_types:
|
||||
logger.info("enable_general_types=False,仅使用场景类型")
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
@@ -130,7 +157,7 @@ class ExtractionOrchestrator:
|
||||
llm_client=llm_client,
|
||||
config=self.config.statement_extraction,
|
||||
)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client, language=language)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||
|
||||
logger.info("ExtractionOrchestrator 初始化完成")
|
||||
|
||||
@@ -14,7 +14,7 @@ import time
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_models import (
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
@@ -118,7 +118,7 @@ class OntologyExtractor:
|
||||
logger.info(
|
||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
||||
f"timeout={timeout}"
|
||||
f"timeout={timeout}, language={language}"
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
@@ -8,6 +8,7 @@ from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
@@ -17,14 +18,21 @@ logger = get_memory_logger(__name__)
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient, language: str = "zh"):
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for processing
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types
|
||||
for entity classification guidance
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.ontology_types = ontology_types
|
||||
self.language = language
|
||||
|
||||
def _get_language(self) -> str:
|
||||
@@ -51,7 +59,8 @@ class TripletExtractor:
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language()
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
12
api/app/core/memory/utils/ontology/__init__.py
Normal file
12
api/app/core/memory/utils/ontology/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体解析工具模块
|
||||
|
||||
本模块提供本体文件解析功能,支持多种 RDF 格式的本体文件解析。
|
||||
|
||||
Modules:
|
||||
ontology_parser: 本体文件解析器
|
||||
"""
|
||||
|
||||
from .ontology_parser import MultiOntologyParser, OntologyParser
|
||||
|
||||
__all__ = ["OntologyParser", "MultiOntologyParser"]
|
||||
366
api/app/core/memory/utils/ontology/ontology_parser.py
Normal file
366
api/app/core/memory/utils/ontology/ontology_parser.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体文件解析器模块
|
||||
|
||||
本模块提供统一的本体文件解析功能,支持多种 RDF 格式:
|
||||
- Turtle (.ttl)
|
||||
- OWL/XML (.owl)
|
||||
- RDF/XML (.rdf)
|
||||
- N-Triples (.nt)
|
||||
- JSON-LD (.jsonld)
|
||||
|
||||
解析器会自动根据文件扩展名推断格式,并在解析失败时尝试其他格式。
|
||||
解析结果包含类定义的名称、URI、多语言标签、描述和父类信息。
|
||||
|
||||
Classes:
|
||||
OntologyParser: 统一本体文件解析器
|
||||
MultiOntologyParser: 多本体文件解析器
|
||||
|
||||
Example:
|
||||
>>> parser = OntologyParser("ontology.ttl")
|
||||
>>> registry = parser.parse()
|
||||
>>> print(f"解析了 {len(registry.types)} 个类型")
|
||||
|
||||
>>> multi_parser = MultiOntologyParser(["ontology1.ttl", "ontology2.owl"])
|
||||
>>> merged_registry = multi_parser.parse_all()
|
||||
>>> print(f"合并后共 {len(merged_registry.types)} 个类型")
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from rdflib import OWL, RDF, RDFS, Graph, URIRef
|
||||
|
||||
from app.core.memory.models.ontology_general_models import (
|
||||
GeneralOntologyType,
|
||||
GeneralOntologyTypeRegistry,
|
||||
OntologyFileFormat,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyParser:
|
||||
"""统一本体文件解析器
|
||||
|
||||
解析本体文件并提取类定义,构建类型注册表。支持多种 RDF 格式,
|
||||
并提供格式自动推断和回退机制。
|
||||
|
||||
Attributes:
|
||||
file_path: 本体文件路径
|
||||
file_format: 文件格式,如果未指定则根据扩展名推断
|
||||
graph: rdflib Graph 实例,用于存储解析后的 RDF 数据
|
||||
|
||||
Example:
|
||||
>>> parser = OntologyParser("dbpedia.owl")
|
||||
>>> registry = parser.parse()
|
||||
>>> person_type = registry.get_type("Person")
|
||||
>>> if person_type:
|
||||
... print(f"Person URI: {person_type.class_uri}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_format: Optional[OntologyFileFormat] = None,
|
||||
):
|
||||
"""初始化解析器
|
||||
|
||||
Args:
|
||||
file_path: 本体文件路径
|
||||
file_format: 文件格式,如果未指定则根据扩展名自动推断
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.file_format = file_format or OntologyFileFormat.from_extension(file_path)
|
||||
self.graph = Graph()
|
||||
|
||||
def parse(self) -> GeneralOntologyTypeRegistry:
|
||||
"""解析本体文件,返回类型注册表
|
||||
|
||||
首先尝试使用推断的格式解析文件,如果失败则尝试其他格式。
|
||||
解析成功后,遍历所有 owl:Class 和 rdfs:Class 定义,
|
||||
提取类信息并构建层次结构。
|
||||
|
||||
Returns:
|
||||
GeneralOntologyTypeRegistry: 包含所有解析出的类型和层次结构的注册表
|
||||
|
||||
Raises:
|
||||
ValueError: 当所有格式都无法解析文件时抛出
|
||||
"""
|
||||
logger.info(f"开始解析本体文件: {self.file_path}")
|
||||
|
||||
# 尝试解析,失败则尝试其他格式
|
||||
self._parse_with_fallback()
|
||||
|
||||
registry = GeneralOntologyTypeRegistry()
|
||||
registry.source_files.append(self.file_path)
|
||||
|
||||
# 遍历 owl:Class
|
||||
for class_uri in self.graph.subjects(RDF.type, OWL.Class):
|
||||
type_info = self._parse_class(class_uri)
|
||||
if type_info:
|
||||
registry.types[type_info.class_name] = type_info
|
||||
self._update_hierarchy(registry, type_info)
|
||||
|
||||
# 遍历 rdfs:Class(避免重复)
|
||||
for class_uri in self.graph.subjects(RDF.type, RDFS.Class):
|
||||
uri_str = str(class_uri)
|
||||
# 检查是否已经作为 owl:Class 解析过
|
||||
if uri_str not in [t.class_uri for t in registry.types.values()]:
|
||||
type_info = self._parse_class(class_uri)
|
||||
if type_info and type_info.class_name not in registry.types:
|
||||
registry.types[type_info.class_name] = type_info
|
||||
self._update_hierarchy(registry, type_info)
|
||||
|
||||
logger.info(f"本体解析完成: {len(registry.types)} 个类型")
|
||||
return registry
|
||||
|
||||
def _parse_with_fallback(self) -> None:
|
||||
"""尝试解析文件,失败时尝试其他格式
|
||||
|
||||
首先使用推断的格式解析,如果失败则依次尝试 RDF_XML 和 TURTLE 格式。
|
||||
|
||||
Raises:
|
||||
ValueError: 当所有格式都无法解析文件时抛出
|
||||
"""
|
||||
try:
|
||||
self.graph.parse(self.file_path, format=self.file_format.value)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"使用 {self.file_format.value} 格式解析失败: {e}")
|
||||
|
||||
# 尝试其他格式
|
||||
fallback_formats = [
|
||||
OntologyFileFormat.RDF_XML,
|
||||
OntologyFileFormat.TURTLE,
|
||||
OntologyFileFormat.N_TRIPLES,
|
||||
OntologyFileFormat.JSON_LD,
|
||||
]
|
||||
|
||||
for fmt in fallback_formats:
|
||||
if fmt != self.file_format:
|
||||
try:
|
||||
self.graph.parse(self.file_path, format=fmt.value)
|
||||
logger.info(f"使用回退格式 {fmt.value} 解析成功")
|
||||
return
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise ValueError(f"无法解析本体文件: {self.file_path}")
|
||||
|
||||
def _update_hierarchy(
|
||||
self,
|
||||
registry: GeneralOntologyTypeRegistry,
|
||||
type_info: GeneralOntologyType
|
||||
) -> None:
|
||||
"""更新层次结构
|
||||
|
||||
如果类型有父类,将其添加到层次结构中。
|
||||
|
||||
Args:
|
||||
registry: 类型注册表
|
||||
type_info: 类型信息
|
||||
"""
|
||||
if type_info.parent_class:
|
||||
if type_info.parent_class not in registry.hierarchy:
|
||||
registry.hierarchy[type_info.parent_class] = set()
|
||||
registry.hierarchy[type_info.parent_class].add(type_info.class_name)
|
||||
|
||||
def _parse_class(self, class_uri: URIRef) -> Optional[GeneralOntologyType]:
|
||||
"""解析单个类定义
|
||||
|
||||
从 RDF 图中提取类的名称、URI、标签、描述和父类信息。
|
||||
过滤空白节点和内置类型(Thing、Resource)。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
GeneralOntologyType 实例,如果应该跳过该类则返回 None
|
||||
"""
|
||||
uri_str = str(class_uri)
|
||||
class_name = self._extract_local_name(uri_str)
|
||||
|
||||
# 过滤空白节点和内置类型
|
||||
if not class_name:
|
||||
return None
|
||||
if class_name.startswith('_:'):
|
||||
return None
|
||||
if class_name in ('Thing', 'Resource'):
|
||||
return None
|
||||
# 过滤空白节点 URI(以 _: 开头或包含空白节点标识)
|
||||
if uri_str.startswith('_:'):
|
||||
return None
|
||||
|
||||
# 提取标签
|
||||
labels = self._extract_labels(class_uri)
|
||||
|
||||
# 提取描述
|
||||
description = self._extract_description(class_uri)
|
||||
|
||||
# 提取父类
|
||||
parent_class = self._extract_parent_class(class_uri)
|
||||
|
||||
return GeneralOntologyType(
|
||||
class_name=class_name,
|
||||
class_uri=uri_str,
|
||||
labels=labels,
|
||||
description=description,
|
||||
parent_class=parent_class,
|
||||
source_file=self.file_path
|
||||
)
|
||||
|
||||
def _extract_labels(self, class_uri: URIRef) -> dict:
|
||||
"""提取类的多语言标签
|
||||
|
||||
从 rdfs:label 属性中提取所有语言的标签。
|
||||
如果没有标签,使用类名作为英文标签。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
语言代码到标签文本的字典
|
||||
"""
|
||||
labels = {}
|
||||
for label in self.graph.objects(class_uri, RDFS.label):
|
||||
lang = getattr(label, 'language', None) or "en"
|
||||
labels[lang] = str(label)
|
||||
|
||||
# 如果没有标签,使用类名作为默认标签
|
||||
if not labels:
|
||||
class_name = self._extract_local_name(str(class_uri))
|
||||
if class_name:
|
||||
labels["en"] = class_name
|
||||
|
||||
return labels
|
||||
|
||||
def _extract_description(self, class_uri: URIRef) -> Optional[str]:
|
||||
"""提取类的描述
|
||||
|
||||
从 rdfs:comment 属性中提取描述,优先使用英文描述。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
类的描述文本,如果没有则返回 None
|
||||
"""
|
||||
description = None
|
||||
for comment in self.graph.objects(class_uri, RDFS.comment):
|
||||
lang = getattr(comment, 'language', None)
|
||||
# 优先使用英文描述
|
||||
if lang == "en":
|
||||
return str(comment)
|
||||
# 如果还没有描述,使用无语言标记或其他语言的描述
|
||||
if description is None:
|
||||
description = str(comment)
|
||||
return description
|
||||
|
||||
def _extract_parent_class(self, class_uri: URIRef) -> Optional[str]:
|
||||
"""提取类的父类
|
||||
|
||||
从 rdfs:subClassOf 属性中提取第一个有效的父类。
|
||||
过滤内置类型(Thing、Resource)和空白节点。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
父类名称,如果没有有效父类则返回 None
|
||||
"""
|
||||
for parent_uri in self.graph.objects(class_uri, RDFS.subClassOf):
|
||||
parent_uri_str = str(parent_uri)
|
||||
# 跳过空白节点
|
||||
if parent_uri_str.startswith('_:'):
|
||||
continue
|
||||
|
||||
parent_name = self._extract_local_name(parent_uri_str)
|
||||
# 过滤内置类型
|
||||
if parent_name and parent_name not in ('Thing', 'Resource'):
|
||||
return parent_name
|
||||
|
||||
return None
|
||||
|
||||
def _extract_local_name(self, uri: str) -> Optional[str]:
|
||||
"""从 URI 中提取本地名称
|
||||
|
||||
支持两种常见的 URI 格式:
|
||||
1. 使用 # 分隔的 URI,如 http://example.org/ontology#Person
|
||||
2. 使用 / 分隔的 URI,如 http://dbpedia.org/ontology/Person
|
||||
|
||||
Args:
|
||||
uri: 完整的 URI 字符串
|
||||
|
||||
Returns:
|
||||
本地名称,如果无法提取则返回 None
|
||||
"""
|
||||
# 处理空白节点
|
||||
if uri.startswith('_:'):
|
||||
return None
|
||||
|
||||
# 尝试使用 # 分隔
|
||||
if '#' in uri:
|
||||
local_name = uri.rsplit('#', 1)[1]
|
||||
if local_name:
|
||||
return local_name
|
||||
|
||||
# 尝试使用 / 分隔
|
||||
if '/' in uri:
|
||||
local_name = uri.rsplit('/', 1)[1]
|
||||
if local_name:
|
||||
return local_name
|
||||
|
||||
# 使用正则表达式作为最后手段
|
||||
match = re.search(r'[#/]([^#/]+)$', uri)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
class MultiOntologyParser:
|
||||
"""多本体文件解析器
|
||||
|
||||
支持加载多个本体文件并将它们合并到一个统一的类型注册表中。
|
||||
先加载的文件中的类型定义优先保留(当存在同名类型时)。
|
||||
|
||||
Attributes:
|
||||
file_paths: 本体文件路径列表
|
||||
|
||||
Example:
|
||||
>>> parser = MultiOntologyParser([
|
||||
... "General_purpose_entity.ttl",
|
||||
... "domain_specific.owl"
|
||||
... ])
|
||||
>>> registry = parser.parse_all()
|
||||
>>> print(f"合并后共 {len(registry.types)} 个类型")
|
||||
"""
|
||||
|
||||
def __init__(self, file_paths: List[str]):
|
||||
"""初始化多文件解析器
|
||||
|
||||
Args:
|
||||
file_paths: 本体文件路径列表
|
||||
"""
|
||||
self.file_paths = file_paths
|
||||
|
||||
def parse_all(self) -> GeneralOntologyTypeRegistry:
|
||||
"""解析所有本体文件并合并
|
||||
|
||||
依次解析每个本体文件,并将结果合并到一个统一的注册表中。
|
||||
如果某个文件解析失败,会记录警告日志并跳过该文件继续处理。
|
||||
|
||||
Returns:
|
||||
GeneralOntologyTypeRegistry: 合并后的类型注册表
|
||||
"""
|
||||
merged_registry = GeneralOntologyTypeRegistry()
|
||||
|
||||
for file_path in self.file_paths:
|
||||
try:
|
||||
parser = OntologyParser(file_path)
|
||||
registry = parser.parse()
|
||||
merged_registry.merge(registry)
|
||||
logger.info(f"已合并本体文件: {file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过无法解析的本体文件 {file_path}: {e}")
|
||||
|
||||
logger.info(f"多本体合并完成: 共 {len(merged_registry.types)} 个类型")
|
||||
return merged_registry
|
||||
@@ -9,22 +9,29 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompt_dir = os.path.join(current_dir, "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def get_prompts(message: str) -> list[dict]:
|
||||
async def get_prompts(message: str, language: str = "zh") -> list[dict]:
|
||||
"""
|
||||
Renders system and user prompts using Jinja2 templates.
|
||||
|
||||
Args:
|
||||
message: The message content
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
List of message dictionaries with role and content
|
||||
"""
|
||||
system_template = prompt_env.get_template("system.jinja2")
|
||||
user_template = prompt_env.get_template("user.jinja2")
|
||||
|
||||
system_prompt = system_template.render()
|
||||
user_prompt = user_template.render(message=message)
|
||||
system_prompt = system_template.render(language=language)
|
||||
user_prompt = user_template.render(message=message, language=language)
|
||||
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('system', system_prompt)
|
||||
log_prompt_rendering('user', user_prompt)
|
||||
# 可选:记录模板渲染信息(仅当 prompt_templates.log 存在时生效)
|
||||
log_template_rendering('system.jinja2', {})
|
||||
log_template_rendering('user.jinja2', {'message': message})
|
||||
log_template_rendering('system.jinja2', {'language': language})
|
||||
log_template_rendering('user.jinja2', {'message': message, 'language': language})
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
@@ -38,6 +45,7 @@ async def render_statement_extraction_prompt(
|
||||
include_dialogue_context: bool = False,
|
||||
dialogue_content: str | None = None,
|
||||
max_dialogue_chars: int | None = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the statement extraction prompt using the extract_statement.jinja2 template.
|
||||
@@ -46,6 +54,11 @@ async def render_statement_extraction_prompt(
|
||||
chunk_content: The content of the chunk to process
|
||||
definitions: Label definitions for statement classification
|
||||
json_schema: JSON schema for the expected output format
|
||||
granularity: Extraction granularity level (1-3)
|
||||
include_dialogue_context: Whether to include full dialogue context
|
||||
dialogue_content: Full dialogue content for context
|
||||
max_dialogue_chars: Maximum characters for dialogue context
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -69,6 +82,7 @@ async def render_statement_extraction_prompt(
|
||||
granularity=granularity,
|
||||
include_dialogue_context=include_dialogue_context,
|
||||
dialogue_context=ctx,
|
||||
language=language,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('statement extraction', rendered_prompt)
|
||||
@@ -90,6 +104,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide: dict,
|
||||
statement_guide: dict,
|
||||
json_schema: dict,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the temporal extraction prompt using the extract_temporal.jinja2 template.
|
||||
@@ -100,6 +115,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide: Guidance on temporal types.
|
||||
statement_guide: Guidance on statement types.
|
||||
json_schema: JSON schema for the expected output format.
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
@@ -111,6 +127,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide=temporal_guide,
|
||||
statement_guide=statement_guide,
|
||||
json_schema=json_schema,
|
||||
language=language,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('temporal extraction', rendered_prompt)
|
||||
@@ -130,6 +147,7 @@ def render_entity_dedup_prompt(
|
||||
context: dict,
|
||||
json_schema: dict,
|
||||
disambiguation_mode: bool = False,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Render the entity deduplication prompt using the entity_dedup.jinja2 template.
|
||||
@@ -139,6 +157,8 @@ def render_entity_dedup_prompt(
|
||||
entity_b: Dict of entity B attributes
|
||||
context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements)
|
||||
json_schema: JSON schema for the structured output (EntityDedupDecision)
|
||||
disambiguation_mode: Whether to use disambiguation mode
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -157,6 +177,7 @@ def render_entity_dedup_prompt(
|
||||
relation_statements=context.get("relation_statements", []),
|
||||
json_schema=json_schema,
|
||||
disambiguation_mode=disambiguation_mode,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===")
|
||||
@@ -177,7 +198,14 @@ def render_entity_dedup_prompt(
|
||||
|
||||
# Args:
|
||||
# entity_a: Dict of entity A attributes
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
|
||||
async def render_triplet_extraction_prompt(
|
||||
statement: str,
|
||||
chunk_content: str,
|
||||
json_schema: dict,
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
|
||||
@@ -187,17 +215,31 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
json_schema: JSON schema for the expected output format
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
ontology_type_section = ontology_types.to_prompt_section()
|
||||
ontology_type_names = ontology_types.get_type_names()
|
||||
type_hierarchy_hints = ontology_types.get_type_hierarchy_hints()
|
||||
|
||||
rendered_prompt = template.render(
|
||||
statement=statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=json_schema,
|
||||
predicate_instructions=predicate_instructions,
|
||||
language=language
|
||||
language=language,
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
@@ -207,7 +249,10 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
'chunk_content': 'str',
|
||||
'json_schema': 'TripletExtractionResponse.schema',
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
||||
'language': language
|
||||
'language': language,
|
||||
'ontology_types': bool(ontology_type_section),
|
||||
'ontology_type_count': len(ontology_type_names),
|
||||
'type_hierarchy_hints_count': len(type_hierarchy_hints),
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -249,7 +294,8 @@ async def render_memory_summary_prompt(
|
||||
async def render_emotion_extraction_prompt(
|
||||
statement: str,
|
||||
extract_keywords: bool,
|
||||
enable_subject: bool
|
||||
enable_subject: bool,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the emotion extraction prompt using the extract_emotion.jinja2 template.
|
||||
@@ -258,6 +304,7 @@ async def render_emotion_extraction_prompt(
|
||||
statement: The statement to analyze
|
||||
extract_keywords: Whether to extract emotion keywords
|
||||
enable_subject: Whether to enable subject classification
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -266,7 +313,8 @@ async def render_emotion_extraction_prompt(
|
||||
rendered_prompt = template.render(
|
||||
statement=statement,
|
||||
extract_keywords=extract_keywords,
|
||||
enable_subject=enable_subject
|
||||
enable_subject=enable_subject,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -467,7 +515,8 @@ async def render_ontology_extraction_prompt(
|
||||
'scenario_len': len(scenario) if scenario else 0,
|
||||
'domain': domain,
|
||||
'max_classes': max_classes,
|
||||
'json_schema': 'OntologyExtractionResponse.schema'
|
||||
'json_schema': 'OntologyExtractionResponse.schema',
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
===任务===
|
||||
===Task===
|
||||
{% if language == "zh" %}
|
||||
你是一个实体去重/消歧判断助手。你将被提供两个实体的详细信息和上下文,请严格根据指引判断它们是否是同一真实世界实体,并在需要时进行类型消歧。
|
||||
|
||||
模式: {{ '消歧模式' if disambiguation_mode else '去重模式' }}
|
||||
{% else %}
|
||||
You are an entity deduplication/disambiguation assistant. You will be provided with detailed information and context for two entities. Please strictly follow the guidelines to determine whether they are the same real-world entity and perform type disambiguation when necessary.
|
||||
|
||||
===输入===
|
||||
Mode: {{ 'Disambiguation Mode' if disambiguation_mode else 'Deduplication Mode' }}
|
||||
{% endif %}
|
||||
|
||||
===Input===
|
||||
{% if language == "zh" %}
|
||||
实体A:
|
||||
- 名称: "{{ entity_a.name | default('') }}"
|
||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||
@@ -34,8 +41,41 @@
|
||||
{% for s in relation_statements %}
|
||||
- {{ s }}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
Entity A:
|
||||
- Name: "{{ entity_a.name | default('') }}"
|
||||
- Type: "{{ entity_a.entity_type | default('') }}"
|
||||
- Description: "{{ entity_a.description | default('') }}"
|
||||
- Aliases: {{ entity_a.aliases | default([]) }}
|
||||
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
|
||||
{# - Summary: "{{ entity_a.fact_summary | default('') }}" #}
|
||||
- Connection Strength: "{{ entity_a.connect_strength | default('') }}"
|
||||
|
||||
===判定指引===
|
||||
Entity B:
|
||||
- Name: "{{ entity_b.name | default('') }}"
|
||||
- Type: "{{ entity_b.entity_type | default('') }}"
|
||||
- Description: "{{ entity_b.description | default('') }}"
|
||||
- Aliases: {{ entity_b.aliases | default([]) }}
|
||||
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
|
||||
{# - Summary: "{{ entity_b.fact_summary | default('') }}" #}
|
||||
- Connection Strength: "{{ entity_b.connect_strength | default('') }}"
|
||||
|
||||
Context:
|
||||
- Same Group: {{ same_group | default(false) }}
|
||||
- Type Consistent or Unknown: {{ type_ok | default(false) }}
|
||||
- Type Similarity (0-1): {{ type_similarity | default(0.0) }}
|
||||
- Name Text Similarity (0-1): {{ name_text_sim | default(0.0) }}
|
||||
- Name Embedding Similarity (0-1): {{ name_embed_sim | default(0.0) }}
|
||||
- Name Contains Relationship: {{ name_contains | default(false) }}
|
||||
- Context Co-occurrence (same statement refers to both): {{ co_occurrence | default(false) }}
|
||||
- Related Relationship Statements (from entity-entity edges):
|
||||
{% for s in relation_statements %}
|
||||
- {{ s }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
===Guidelines===
|
||||
{% if language == "zh" %}
|
||||
{% if disambiguation_mode %}
|
||||
- 这是"同名但类型不同"的消歧场景。请判断两者是否指向同一真实世界实体。
|
||||
- 综合名称文本/向量相似度、别名、描述、摘要与上下文关系(同源与关系陈述)进行判断。
|
||||
@@ -68,8 +108,43 @@
|
||||
- 优先保留连接强度更强(strong/both)者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。
|
||||
- **注意**:别名(aliases)已在三元组提取阶段获取,合并时会自动整合,无需在此阶段提取。
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if disambiguation_mode %}
|
||||
- This is a disambiguation scenario for "same name but different types". Please determine whether they refer to the same real-world entity.
|
||||
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships (co-occurrence and relationship statements).
|
||||
- **Alias Handling (High Priority)**:
|
||||
* If the alias lists of both entities have intersections, this is a strong signal of identity
|
||||
* If one entity's name appears in another entity's aliases, it should be considered a high-confidence match
|
||||
* If one entity's alias exactly matches another entity's name, it should be considered a high-confidence match
|
||||
* Alias matching weight should be higher than pure name text similarity
|
||||
- If unable to determine with sufficient confidence, handle conservatively: do not merge, and suggest blocking this pair in other fuzzy/heuristic merges (block_pair=true).
|
||||
- If merging is needed (should_merge=true), select the "canonical entity" (canonical_idx) and **must** provide a suggested unified type (suggested_type).
|
||||
- **Type Unification Principles (Important)**:
|
||||
* Prioritize more specific and accurate types (e.g., HistoricalPeriod over Organization, MilitaryCapability over Concept)
|
||||
* If both types are specific but different, choose the type that best matches the entity's core semantics
|
||||
* Generic types (Concept, Phenomenon, Condition, State, Attribute, Event) have lower priority than domain-specific types
|
||||
* Suggested type must be consistent with context and entity description
|
||||
- Canonical entity priority: higher connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
|
||||
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
|
||||
{% else %}
|
||||
- If entity types are the same or either is UNKNOWN/empty, can proceed as candidates; if types clearly conflict (e.g., person vs. item), unless aliases and descriptions are highly consistent, determine as different entities.
|
||||
- **Alias Matching Priority (Highest Priority)**:
|
||||
* If Entity A's name exactly matches any of Entity B's aliases, it should be considered a high-confidence match
|
||||
* If Entity B's name exactly matches any of Entity A's aliases, it should be considered a high-confidence match
|
||||
* If any alias of Entity A exactly matches any alias of Entity B, it should be considered a high-confidence match
|
||||
* When aliases match exactly, merging should be considered even if name text similarity is low
|
||||
* Alias matching confidence should be higher than pure name similarity matching
|
||||
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships.
|
||||
- When context co-occurs or there are clear relationship statements supporting identity (e.g., the same object is repeatedly mentioned or aliases correspond), the judgment threshold can be moderately lowered.
|
||||
- Conservative decision: when unable to determine with sufficient confidence, do not merge (same_entity=false).
|
||||
- If merging is needed, select the "canonical entity to retain" (canonical_idx) as the more appropriate one:
|
||||
- Prioritize retaining the one with stronger connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
|
||||
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
**Output format**
|
||||
{% if language == "zh" %}
|
||||
{% if disambiguation_mode %}
|
||||
返回JSON格式,必须包含以下字段:
|
||||
{
|
||||
@@ -103,6 +178,41 @@
|
||||
- confidence: 决策的置信度,范围0.0-1.0
|
||||
- reason: 决策理由的简短说明
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if disambiguation_mode %}
|
||||
Return JSON format with the following required fields:
|
||||
{
|
||||
"should_merge": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"block_pair": boolean,
|
||||
"suggested_type": "string or null",
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**Field Descriptions**:
|
||||
- should_merge: Whether these two entities should be merged (true/false)
|
||||
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
|
||||
- confidence: Confidence level of the decision, range 0.0-1.0
|
||||
- block_pair: Whether to block this pair in other fuzzy/heuristic merges (true/false)
|
||||
- suggested_type: Suggested unified type (string or null)
|
||||
- reason: Brief explanation of the decision
|
||||
{% else %}
|
||||
Return JSON format with the following required fields:
|
||||
{
|
||||
"same_entity": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**Field Descriptions**:
|
||||
- same_entity: Whether the two entities refer to the same real-world entity (true/false)
|
||||
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
|
||||
- confidence: Confidence level of the decision, range 0.0-1.0
|
||||
- reason: Brief explanation of the decision
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
@@ -110,5 +220,9 @@
|
||||
3. Do not include line breaks within JSON string values
|
||||
4. Test your JSON output mentally to ensure it can be parsed correctly
|
||||
|
||||
{% if language == "zh" %}
|
||||
输出语言应始终与输入语言相同。
|
||||
{% else %}
|
||||
The output language should always be the same as the input language.
|
||||
{% endif %}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -17,9 +17,18 @@
|
||||
#}
|
||||
|
||||
{% set scene_instructions = {
|
||||
'education': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'online_service': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'outbound': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。'
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{% set scene_key = pruning_scene %}
|
||||
@@ -27,8 +36,9 @@
|
||||
{% set scene_key = 'education' %}
|
||||
{% endif %}
|
||||
|
||||
{% set instruction = scene_instructions[scene_key] %}
|
||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
场景说明:{{ instruction }}
|
||||
|
||||
@@ -46,4 +56,24 @@
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
}
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
Scenario Description: {{ instruction }}
|
||||
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
"""
|
||||
|
||||
Output strict JSON only (fixed keys, order doesn't matter):
|
||||
{
|
||||
"is_related": <true or false>,
|
||||
"times": [<string>...],
|
||||
"ids": [<string>...],
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{% if language == "zh" %}
|
||||
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
|
||||
|
||||
陈述句:{{ statement }}
|
||||
@@ -55,3 +56,62 @@
|
||||
- 主体分类要准确,优先识别用户本人(self)
|
||||
|
||||
请以 JSON 格式返回结果。
|
||||
{% else %}
|
||||
You are a professional emotion analysis expert. Please analyze the emotional information in the following statement.
|
||||
|
||||
Statement: {{ statement }}
|
||||
|
||||
Please extract the following information:
|
||||
|
||||
1. emotion_type (Emotion Type):
|
||||
- joy: happiness, delight, pleasure, satisfaction, cheerfulness
|
||||
- sadness: sorrow, grief, disappointment, depression, regret
|
||||
- anger: rage, irritation, dissatisfaction, annoyance, frustration
|
||||
- fear: anxiety, worry, concern, nervousness, apprehension
|
||||
- surprise: astonishment, amazement, shock, wonder
|
||||
- neutral: neutral, objective statement, no obvious emotion
|
||||
|
||||
2. emotion_intensity (Emotion Intensity):
|
||||
- 0.0-0.3: weak emotion
|
||||
- 0.3-0.7: moderate emotion
|
||||
- 0.7-1.0: strong emotion
|
||||
|
||||
{% if extract_keywords %}
|
||||
3. emotion_keywords (Emotion Keywords):
|
||||
- Words directly expressing emotions in the original sentence
|
||||
- Extract up to 3 keywords
|
||||
- Return empty list if no obvious emotion words
|
||||
{% else %}
|
||||
3. emotion_keywords (Emotion Keywords):
|
||||
- Return empty list
|
||||
{% endif %}
|
||||
|
||||
{% if enable_subject %}
|
||||
4. emotion_subject (Emotion Subject):
|
||||
- self: user's own emotions (includes "I", "we", "us" and other first-person pronouns)
|
||||
- other: others' emotions (includes names, "he/she" and other third-person pronouns)
|
||||
- object: evaluation of things (for products, places, events, etc.)
|
||||
|
||||
Note:
|
||||
- If multiple subjects are present, prioritize identifying the user (self)
|
||||
- If the subject cannot be clearly determined, default to self
|
||||
|
||||
5. emotion_target (Emotion Target):
|
||||
- If there is a clear emotion target, extract its name
|
||||
- If there is no clear target, return null
|
||||
{% else %}
|
||||
4. emotion_subject (Emotion Subject):
|
||||
- Default to self
|
||||
|
||||
5. emotion_target (Emotion Target):
|
||||
- Return null
|
||||
{% endif %}
|
||||
|
||||
Notes:
|
||||
- If the statement is an objective factual statement with no obvious emotion, mark as neutral
|
||||
- Emotion intensity should match the context, do not over-interpret
|
||||
- Emotion keywords should be accurate, do not add words not in the original sentence
|
||||
- Subject classification should be accurate, prioritize identifying the user (self)
|
||||
|
||||
Please return the result in JSON format.
|
||||
{% endif %}
|
||||
|
||||
@@ -24,6 +24,23 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
|
||||
{% endif %}
|
||||
{%- endif %}
|
||||
|
||||
===Output Language===
|
||||
{% if language == "en" -%}
|
||||
**IMPORTANT: All output content MUST be in English.**
|
||||
- Class names (name field): English in PascalCase format
|
||||
- Chinese name (name_chinese field): Provide Chinese translation
|
||||
- Descriptions: MUST be in English
|
||||
- Examples: MUST be in English
|
||||
- Domain: MUST be in English
|
||||
{%- else -%}
|
||||
**IMPORTANT: Output content language requirements:**
|
||||
- Class names (name field): English in PascalCase format
|
||||
- Chinese name (name_chinese field): Chinese translation
|
||||
- Descriptions: MUST be in Chinese (中文)
|
||||
- Examples: MUST be in Chinese (中文)
|
||||
- Domain: Can be in Chinese or English
|
||||
{%- endif %}
|
||||
|
||||
===Extraction Rules===
|
||||
|
||||
{% if language == "zh" %}
|
||||
@@ -99,16 +116,31 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
|
||||
- Aim for a balanced set covering the main concepts in the scenario
|
||||
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
||||
|
||||
|
||||
**5. Clear Descriptions:**
|
||||
{% if language == "en" -%}
|
||||
- Provide concise, informative descriptions in English (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural English language that explains the class's role in the domain
|
||||
{%- else -%}
|
||||
- Provide concise, informative descriptions in English (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural English language
|
||||
{%- endif %}
|
||||
|
||||
**6. Concrete Examples:**
|
||||
{% if language == "en" -%}
|
||||
- Provide 2-5 concrete instance examples in English for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Use natural English language for examples
|
||||
- Example format: ["Example1", "Example2", "Example3"]
|
||||
{%- else -%}
|
||||
- Provide 2-5 concrete instance examples in English for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Example format: ["Example1", "Example2", "Example3"]
|
||||
{%- endif %}
|
||||
|
||||
**7. Class Hierarchy:**
|
||||
- Identify parent-child relationships where applicable
|
||||
@@ -234,6 +266,64 @@ This scenario belongs to the **{{ domain }}** domain. Consider domain-specific c
|
||||
}
|
||||
|
||||
{% else %}
|
||||
|
||||
{% if language == "en" -%}
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Patient",
|
||||
"name_chinese": "患者",
|
||||
"description": "A person who receives medical care or treatment at a healthcare facility",
|
||||
"examples": ["Outpatient", "Inpatient", "Emergency patient", "Chronic disease patient"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Person",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "MedicalProcedure",
|
||||
"name_chinese": "医疗程序",
|
||||
"description": "A systematic operation or process performed for medical diagnosis or treatment",
|
||||
"examples": ["Surgery", "Blood test", "X-ray examination", "Vaccination"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Diagnosis",
|
||||
"name_chinese": "诊断",
|
||||
"description": "The identification of a disease or condition based on symptoms and examination results",
|
||||
"examples": ["Diabetes diagnosis", "Cancer diagnosis", "Flu diagnosis"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Doctor",
|
||||
"name_chinese": "医生",
|
||||
"description": "A licensed medical professional who diagnoses and treats patients",
|
||||
"examples": ["General practitioner", "Surgeon", "Cardiologist"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Treatment",
|
||||
"name_chinese": "治疗",
|
||||
"description": "Medical care or therapy provided to cure or manage a disease condition",
|
||||
"examples": ["Medication therapy", "Physical therapy", "Chemotherapy", "Surgical treatment"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
],
|
||||
"domain": "Healthcare",
|
||||
"namespace": "http://example.org/healthcare#"
|
||||
}
|
||||
{%- else -%}
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
@@ -334,6 +424,7 @@ Output:
|
||||
"domain": "Education"
|
||||
}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
|
||||
|
||||
@@ -5,8 +5,13 @@
|
||||
|
||||
===Tasks===
|
||||
|
||||
{% if language == "zh" %}
|
||||
你的任务是根据详细的提取指南,从提供的对话片段中识别和提取陈述句。
|
||||
每个陈述句必须按照下面提到的标准进行标记。
|
||||
{% else %}
|
||||
Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines.
|
||||
Each statement must be labeled as per the criteria mentioned below.
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if inputs %}
|
||||
@@ -17,6 +22,32 @@ Each statement must be labeled as per the criteria mentioned below.
|
||||
|
||||
|
||||
===Extraction Instructions===
|
||||
{% if language == "zh" %}
|
||||
{% if granularity %}
|
||||
{% if granularity == 3 %}
|
||||
原子化和清晰:构建陈述句以清楚地显示单一的主谓宾关系。最好有多个较小的陈述句,而不是一个复杂的陈述句。
|
||||
上下文独立:陈述句必须在不需要阅读整个对话的情况下可以理解。
|
||||
{% elif granularity == 2 %}
|
||||
在句子级别提取陈述句。每个陈述句应对应一个单一、完整的思想(通常是来源中的一个完整句子),但要重新表述以获得最大的清晰度,删除对话填充词(例如,"嗯"、"像"、感叹词)。
|
||||
{% elif granularity == 1 %}
|
||||
仅提取精华句子,并将片段总结为多个独立的陈述句,每个陈述句关注事实陈述、用户偏好、关系和显著的时间上下文。
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
上下文解析要求:
|
||||
- 将指示代词("那个"、"这个"、"那些"、"这些")解析为其具体指代对象
|
||||
- 如果陈述句包含无法从对话上下文中解析的模糊引用,则:
|
||||
a) 扩展陈述句以包含对话早期的缺失上下文
|
||||
b) 标记陈述句为需要额外上下文
|
||||
c) 如果陈述句在没有上下文的情况下变得无意义,则跳过提取
|
||||
|
||||
对话上下文和共指消解:
|
||||
- 将每个陈述句归属于说出它的参与者。
|
||||
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体。
|
||||
- 识别并将抽象引用解析为其具体名称(如果提到)。
|
||||
- 将缩写和首字母缩略词扩展为其完整形式。
|
||||
{% else %}
|
||||
{% if granularity %}
|
||||
{% if granularity == 3 %}
|
||||
Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one.
|
||||
@@ -29,7 +60,7 @@ Extract only essence sentences and summarize the chunk into multiple, standalone
|
||||
{% endif %}
|
||||
|
||||
Context Resolution Requirements:
|
||||
- Resolve demonstrative pronouns ("that," "this," "those","这个", "那个") to their specific referents
|
||||
- Resolve demonstrative pronouns ("that," "this," "those") to their specific referents
|
||||
- If a statement contains vague references that cannot be resolved from the conversation context, either:
|
||||
a) Expand the statement to include the missing context from earlier in the conversation
|
||||
b) Mark the statement as requiring additional context
|
||||
@@ -41,16 +72,36 @@ Conversational Context & Co-reference Resolution:
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context.
|
||||
- Identify and resolve abstract references to their specific names if mentioned.
|
||||
- Expand abbreviations and acronyms to their full form.
|
||||
{% endif %}
|
||||
|
||||
{% if include_dialogue_context %}
|
||||
{% if language == "zh" %}
|
||||
===完整对话上下文===
|
||||
以下是完整的对话上下文,以帮助您理解引用、代词和对话流程:
|
||||
{% else %}
|
||||
===Full Dialogue Context===
|
||||
The following is the complete dialogue context to help you understand references, pronouns, and conversational flow:
|
||||
{% endif %}
|
||||
|
||||
{{ dialogue_context }}
|
||||
|
||||
{% if language == "zh" %}
|
||||
===对话上下文结束===
|
||||
{% else %}
|
||||
===End of Dialogue Context===
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
过滤和格式化:
|
||||
|
||||
- 仅提取陈述句。
|
||||
不要提取问题、命令、问候语或对话填充词。
|
||||
时间精度:
|
||||
|
||||
包括任何明确的日期、时间或定量限定符。
|
||||
如果一个句子既描述了事件的开始(静态)又描述了其持续性质(动态),则将两者提取为单独的陈述句。
|
||||
{% else %}
|
||||
Filtering and Formatting:
|
||||
|
||||
- Extract only declarative statements.
|
||||
@@ -59,18 +110,114 @@ Temporal Precision:
|
||||
|
||||
Include any explicit dates, times, or quantitative qualifiers.
|
||||
If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements.
|
||||
{% endif %}
|
||||
|
||||
{%- if definitions %}
|
||||
{%- for section_key, section_dict in definitions.items() %}
|
||||
==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ====
|
||||
==== {{ tidy(section_key) | upper }} {% if language == "zh" %}定义和指导{% else %}DEFINITIONS & GUIDANCE{% endif %} ====
|
||||
{%- for category, details in section_dict.items() %}
|
||||
{{ loop.index }}. {{ category }}
|
||||
- Definition: {{ details.get("definition", "") }}
|
||||
- {% if language == "zh" %}定义{% else %}Definition{% endif %}: {{ details.get("definition", "") }}
|
||||
{% endfor -%}
|
||||
{% endfor -%}
|
||||
{% endif -%}
|
||||
|
||||
===Examples===
|
||||
{% if language == "zh" %}
|
||||
示例 1: 英文对话
|
||||
示例片段: """
|
||||
日期: 2024年3月15日
|
||||
参与者:
|
||||
- Sarah Chen (用户)
|
||||
- 助手 (AI)
|
||||
|
||||
用户: "我最近一直在尝试水彩画,画了一些花朵。"
|
||||
AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
|
||||
用户: "我认为色彩组合可以改进,但我真的很喜欢玫瑰和百合。"
|
||||
"""
|
||||
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen 最近一直在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "ATEMPORAL",
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
示例 2: 中文对话示例
|
||||
示例片段: """
|
||||
日期: 2024年3月15日
|
||||
参与者:
|
||||
- 张曼婷 (用户)
|
||||
- 小助手 (AI助手)
|
||||
|
||||
用户: "我最近在尝试水彩画,画了一些花朵。"
|
||||
AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
|
||||
用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。"
|
||||
"""
|
||||
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "ATEMPORAL",
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
Example 1: English Conversation
|
||||
Example Chunk: """
|
||||
Date: March 15, 2024
|
||||
@@ -164,8 +311,33 @@ Example Output: {
|
||||
}
|
||||
]
|
||||
}
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
{% if language == "zh" %}
|
||||
===反思过程===
|
||||
|
||||
提取陈述句后,执行以下自我审查步骤:
|
||||
|
||||
**步骤 1: 归属检查**
|
||||
- 确认每个陈述句都正确归属于正确的说话者
|
||||
- 验证说话者名称在整个过程中使用一致
|
||||
- 检查 AI 助手陈述句是否正确归属
|
||||
|
||||
**步骤 2: 完整性审查**
|
||||
- 确保没有遗漏重要的陈述句
|
||||
- 检查时间信息是否保留
|
||||
|
||||
**步骤 3: 分类验证**
|
||||
- 审查 statement_type 分类(FACT/OPINION/PREDICTION/SUGGESTION)
|
||||
- 验证 temporal_type 分配(STATIC/DYNAMIC/ATEMPORAL)
|
||||
- 确保分类与提供的定义一致
|
||||
|
||||
**步骤 4: 最终质量检查**
|
||||
- 删除任何问题、命令或对话填充词
|
||||
- 验证 JSON 格式合规性
|
||||
- 确认输出语言与输入语言匹配
|
||||
{% else %}
|
||||
===Reflection Process===
|
||||
|
||||
After extracting statements, perform the following self-review steps:
|
||||
@@ -188,6 +360,7 @@ After extracting statements, perform the following self-review steps:
|
||||
- Remove any questions, commands, or conversational filler
|
||||
- Verify JSON format compliance
|
||||
- Confirm output language matches input language
|
||||
{% endif %}
|
||||
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
@@ -198,10 +371,21 @@ After extracting statements, perform the following self-review steps:
|
||||
5. Example of proper escaping: "statement": "John said: \"I really like this book.\""
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
{% if language == "zh" %}
|
||||
- 输出语言应始终与输入语言匹配
|
||||
- 如果输入是中文,则用中文提取陈述句
|
||||
- 如果输入是英文,则用英文提取陈述句
|
||||
- 保留原始语言,不要翻译
|
||||
{% else %}
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
- Preserve the original language and do not translate
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
仅返回与以下架构匹配的 JSON 对象数组中提取的标记陈述句列表:
|
||||
{% else %}
|
||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||
{{ json_schema }}
|
||||
{% endif %}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -14,68 +14,113 @@
|
||||
#}
|
||||
# Task
|
||||
|
||||
{% if language == "zh" %}
|
||||
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
|
||||
{% else %}
|
||||
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
|
||||
{% endif %}
|
||||
|
||||
# Input Data
|
||||
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
|
||||
{% if inputs %}
|
||||
{% for key, val in inputs.items() %}
|
||||
- {{ key }}: {{val}}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
# Temporal Fields
|
||||
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
- **valid_at**: 关系/事件开始或成为真实的时间(ISO 8601 格式)
|
||||
- **invalid_at**: 关系/事件结束或停止为真的时间(ISO 8601 格式,如果正在进行则为 null)
|
||||
{% else %}
|
||||
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
|
||||
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
|
||||
{% endif %}
|
||||
|
||||
# Extraction Rules
|
||||
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
|
||||
|
||||
## Core Principles
|
||||
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
|
||||
{% if language == "zh" %}
|
||||
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
|
||||
2. **使用参考/发布日期作为"现在"** 解释相对时间时
|
||||
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
|
||||
4. **对于时间点事件**,仅设置 `valid_at`
|
||||
{% else %}
|
||||
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
|
||||
2. **Use the reference/publication date as "now"** when interpreting relative times
|
||||
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
|
||||
4. **For point-in-time events**, set only `valid_at`
|
||||
{% endif %}
|
||||
|
||||
## Date Format Requirements
|
||||
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
|
||||
{% if language == "zh" %}
|
||||
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
||||
- 如果未指定时间,使用 `00:00:00`(午夜)
|
||||
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
|
||||
- 如果仅提及月份,使用月份的第一天或最后一天
|
||||
- 始终包含时区(如果未指定,使用 `Z` 表示 UTC)
|
||||
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
|
||||
{% else %}
|
||||
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
||||
- If no time specified, use `00:00:00` (midnight)
|
||||
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
|
||||
- If only month mentioned, use first or last day of month
|
||||
- Always include timezone (use `Z` for UTC if unspecified)
|
||||
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
|
||||
{% endif %}
|
||||
|
||||
## Statement Type Rules
|
||||
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
|
||||
|
||||
{{ inputs.get("statement_type") | upper }} Statement Guidance:
|
||||
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
|
||||
{%for key, guide in statement_guide.items() %}
|
||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
||||
{% endfor %}
|
||||
|
||||
**Special Cases:**
|
||||
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
|
||||
{% if language == "zh" %}
|
||||
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
|
||||
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
|
||||
{% else %}
|
||||
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
|
||||
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
|
||||
{% endif %}
|
||||
|
||||
## Temporal Type Rules
|
||||
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
|
||||
|
||||
{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance:
|
||||
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
|
||||
{% for key, guide in temporal_guide.items() %}
|
||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
||||
{% endfor %}
|
||||
|
||||
{% if inputs.get('quarter') and inputs.get('publication_date') %}
|
||||
## Quarter Reference
|
||||
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
|
||||
{% if language == "zh" %}
|
||||
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用(Q1、Q2 等)的日期。
|
||||
{% else %}
|
||||
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
# Output Requirements
|
||||
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
|
||||
|
||||
## JSON Formatting (CRITICAL)
|
||||
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
|
||||
{% if language == "zh" %}
|
||||
1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体
|
||||
2. 使用反斜杠转义内部引号: `\"`
|
||||
3. JSON 字符串值中不要有换行符
|
||||
4. 正确关闭并用逗号分隔所有字段
|
||||
{% else %}
|
||||
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants
|
||||
2. Escape internal quotes with backslash: `\"`
|
||||
3. No line breaks within JSON string values
|
||||
4. Properly close and comma-separate all fields
|
||||
{% endif %}
|
||||
|
||||
## Language
|
||||
## {% if language == "zh" %}语言{% else %}Language{% endif %}
|
||||
{% if language == "zh" %}
|
||||
输出语言必须与输入语言匹配。
|
||||
{% else %}
|
||||
Output language must match input language.
|
||||
{% endif %}
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -15,6 +15,37 @@ Extract entities and knowledge triplets from the given statement.
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
|
||||
**CRITICAL: Use predefined ontology types for entity classification with the following priority:**
|
||||
|
||||
**Type Priority (from highest to lowest):**
|
||||
1. **[场景类型] Scene Types** - Domain-specific types, use these first if applicable
|
||||
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
|
||||
3. **[通用父类] Parent Types** - Provide type hierarchy context
|
||||
|
||||
**Type Matching Rules:**
|
||||
- Entity type MUST exactly match one of the predefined type names
|
||||
- Do NOT modify, translate, or use variations of type names
|
||||
- Prefer scene types over general types when both could apply
|
||||
- If uncertain between types, check the type description for guidance
|
||||
|
||||
**Predefined Ontology Types:**
|
||||
{{ ontology_types }}
|
||||
|
||||
{% if type_hierarchy_hints %}
|
||||
**Type Hierarchy Reference:**
|
||||
The following shows type inheritance relationships (Child → Parent → Grandparent):
|
||||
{% for hint in type_hierarchy_hints %}
|
||||
- {{ hint }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
**Available Type Names (use EXACTLY as shown):**
|
||||
{{ ontology_type_names | join(', ') }}
|
||||
|
||||
{% endif %}
|
||||
===Guidelines===
|
||||
|
||||
**Entity Extraction:**
|
||||
|
||||
@@ -1,2 +1,7 @@
|
||||
{% if language == "zh" %}
|
||||
你是一个从对话消息中提取实体节点的 AI 助手。
|
||||
你的主要任务是提取和分类说话者以及对话中提到的其他重要实体。
|
||||
{% else %}
|
||||
You are an AI assistant that extracts entity nodes from conversational messages.
|
||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
|
||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
|
||||
{% endif %}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{% if language == "zh" %}
|
||||
给定一个对话上下文和一个当前消息。
|
||||
你的任务是提取在当前消息中**明确或隐含**提到的用户名称和年龄。
|
||||
代词引用(如 he/she/they 或 this/that/those)应消歧为引用实体的名称。
|
||||
|
||||
{{ message }}
|
||||
{% else %}
|
||||
You are given a conversation context and a CURRENT MESSAGE.
|
||||
Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
||||
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities.
|
||||
|
||||
{{ message }}
|
||||
{{ message }}
|
||||
{% endif %}
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,7 +20,7 @@ from owlready2 import (
|
||||
OwlReadyInconsistentOntologyError,
|
||||
)
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -583,3 +583,156 @@ class OWLValidator:
|
||||
is_compatible = len(warnings) == 0
|
||||
|
||||
return is_compatible, warnings
|
||||
|
||||
def parse_owl_content(
|
||||
self,
|
||||
owl_content: str,
|
||||
format: str = "rdfxml"
|
||||
) -> List[dict]:
|
||||
"""从 OWL 内容解析出本体类型
|
||||
|
||||
支持解析 RDF/XML、Turtle 和 JSON 格式的 OWL 文件,
|
||||
提取其中定义的 owl:Class 及其 rdfs:label 和 rdfs:comment。
|
||||
|
||||
Args:
|
||||
owl_content: OWL 文件内容字符串
|
||||
format: 文件格式,支持 "rdfxml"、"turtle"、"json"
|
||||
|
||||
Returns:
|
||||
解析出的类型列表,每个元素包含:
|
||||
- name: 类型名称(英文标识符)
|
||||
- name_chinese: 中文名称(如果有)
|
||||
- description: 类型描述
|
||||
- parent_class: 父类名称
|
||||
|
||||
Raises:
|
||||
ValueError: 如果格式不支持或解析失败
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = validator.parse_owl_content(owl_xml, format="rdfxml")
|
||||
>>> for cls in classes:
|
||||
... print(cls["name"], cls["description"])
|
||||
"""
|
||||
valid_formats = ["rdfxml", "turtle", "json"]
|
||||
if format not in valid_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON 格式单独处理
|
||||
if format == "json":
|
||||
return self._parse_json_owl(owl_content)
|
||||
|
||||
# 使用 rdflib 解析 RDF/XML 或 Turtle
|
||||
try:
|
||||
from rdflib import Graph, RDF, RDFS, OWL, Namespace
|
||||
|
||||
g = Graph()
|
||||
rdf_format = "xml" if format == "rdfxml" else "turtle"
|
||||
g.parse(data=owl_content, format=rdf_format)
|
||||
|
||||
classes = []
|
||||
|
||||
# 查找所有 owl:Class
|
||||
for cls_uri in g.subjects(RDF.type, OWL.Class):
|
||||
cls_str = str(cls_uri)
|
||||
|
||||
# 跳过空节点和 OWL 内置类
|
||||
if cls_str.startswith("http://www.w3.org/") or "/.well-known/" in cls_str:
|
||||
continue
|
||||
|
||||
# 提取类名(从 URI 中获取本地名称)
|
||||
if '#' in cls_str:
|
||||
name = cls_str.split('#')[-1]
|
||||
else:
|
||||
name = cls_str.split('/')[-1]
|
||||
|
||||
# 跳过空名称
|
||||
if not name or name == "Thing":
|
||||
continue
|
||||
|
||||
# 获取 rdfs:label(可能有多个,包括中英文)
|
||||
labels = list(g.objects(cls_uri, RDFS.label))
|
||||
name_chinese = None
|
||||
label_str = name # 默认使用 URI 中的名称
|
||||
|
||||
for label in labels:
|
||||
label_text = str(label)
|
||||
# 检查是否包含中文
|
||||
if any('\u4e00' <= char <= '\u9fff' for char in label_text):
|
||||
name_chinese = label_text
|
||||
else:
|
||||
label_str = label_text
|
||||
|
||||
# 获取 rdfs:comment(描述)
|
||||
comments = list(g.objects(cls_uri, RDFS.comment))
|
||||
description = str(comments[0]) if comments else None
|
||||
|
||||
# 获取父类(rdfs:subClassOf)
|
||||
parent_class = None
|
||||
for parent_uri in g.objects(cls_uri, RDFS.subClassOf):
|
||||
parent_str = str(parent_uri)
|
||||
# 跳过 owl:Thing
|
||||
if parent_str == str(OWL.Thing) or parent_str.endswith("#Thing"):
|
||||
continue
|
||||
# 提取父类名称
|
||||
if '#' in parent_str:
|
||||
parent_class = parent_str.split('#')[-1]
|
||||
else:
|
||||
parent_class = parent_str.split('/')[-1]
|
||||
break # 只取第一个非 Thing 的父类
|
||||
|
||||
classes.append({
|
||||
"name": name,
|
||||
"name_chinese": name_chinese,
|
||||
"description": description,
|
||||
"parent_class": parent_class
|
||||
})
|
||||
|
||||
logger.info(f"Parsed {len(classes)} classes from OWL content (format: {format})")
|
||||
return classes
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to parse OWL(文档格式不正确) content: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
def _parse_json_owl(self, json_content: str) -> List[dict]:
|
||||
"""解析 JSON 格式的 OWL 内容
|
||||
|
||||
JSON 格式是简化的本体表示,由 export_to_owl 的 json 格式导出。
|
||||
|
||||
Args:
|
||||
json_content: JSON 格式的 OWL 内容
|
||||
|
||||
Returns:
|
||||
解析出的类型列表
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
|
||||
# 检查是否是我们导出的 JSON 格式
|
||||
if "ontology" in data and "classes" in data["ontology"]:
|
||||
raw_classes = data["ontology"]["classes"]
|
||||
elif "classes" in data:
|
||||
raw_classes = data["classes"]
|
||||
else:
|
||||
raise ValueError("Invalid JSON format: missing 'classes' field")
|
||||
|
||||
classes = []
|
||||
for cls in raw_classes:
|
||||
classes.append({
|
||||
"name": cls.get("name", ""),
|
||||
"name_chinese": cls.get("name_chinese"),
|
||||
"description": cls.get("description"),
|
||||
"parent_class": cls.get("parent_class")
|
||||
})
|
||||
|
||||
logger.info(f"Parsed {len(classes)} classes from JSON content")
|
||||
return classes
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON content: {str(e)}") from e
|
||||
|
||||
@@ -18,6 +18,8 @@ from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
@@ -35,5 +37,7 @@ __all__ = [
|
||||
"JinjaRenderNode",
|
||||
"ParameterExtractorNode",
|
||||
"QuestionClassifierNode",
|
||||
"ToolNode"
|
||||
"ToolNode",
|
||||
"CodeNode",
|
||||
"VariableAggregatorNode"
|
||||
]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
__all__ = ["CodeNode", "CodeNodeConfig"]
|
||||
|
||||
@@ -216,7 +216,7 @@ class LLMNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
return AIMessage(content=content, response_metadata=response.response_metadata)
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
|
||||
@@ -193,7 +193,8 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
self.response_metadata = model_resp.response_metadata
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
model_message = self.process_model_output(model_resp.content)
|
||||
result = json_repair.repair_json(model_message, return_objects=True)
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
return result
|
||||
|
||||
@@ -131,7 +131,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
]
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
result = response.content.strip()
|
||||
result = self.process_model_output(response.content)
|
||||
self.response_metadata = response.response_metadata
|
||||
|
||||
if result in category_names:
|
||||
|
||||
328
api/app/query_ontology_matched_entities.py
Normal file
328
api/app/query_ontology_matched_entities.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
api\scripts\query_ontology_matched_entities.py
|
||||
|
||||
根据 end_user_id 查询 Neo4j 中的 ExtractedEntity 节点,
|
||||
并筛选出 entity_type 与以下类型匹配的实体:
|
||||
1. 场景本体类型(ontology_class 表)
|
||||
2. 通用本体类型(General_purpose_entity.ttl 等文件)
|
||||
|
||||
用法: python scripts/query_ontology_matched_entities.py <end_user_id> [config_id]
|
||||
示例: python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1
|
||||
python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1 fd547bb9-7b9e-47ea-ae53-242d208a31a2
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
from typing import List, Dict, Any, Set, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.ontology_services.ontology_type_loader import (
|
||||
get_general_ontology_registry,
|
||||
is_general_ontology_enabled,
|
||||
)
|
||||
|
||||
|
||||
async def get_entities_by_end_user_id(connector: Neo4jConnector, end_user_id: str) -> List[Dict[str, Any]]:
|
||||
"""从 Neo4j 查询指定 end_user_id 的所有实体"""
|
||||
|
||||
query = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
n.id AS id,
|
||||
n.name AS name,
|
||||
n.entity_type AS entity_type,
|
||||
n.description AS description,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.created_at AS created_at
|
||||
ORDER BY n.created_at DESC
|
||||
"""
|
||||
|
||||
results = await connector.execute_query(query, end_user_id=end_user_id)
|
||||
return results
|
||||
|
||||
|
||||
def get_ontology_types_from_scene(db, scene_id: UUID) -> Set[str]:
|
||||
"""获取场景下所有本体类型名称"""
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_classes = class_repo.get_by_scene(scene_id)
|
||||
return {oc.class_name for oc in ontology_classes}
|
||||
|
||||
|
||||
def get_ontology_types_from_config(db, config_id: UUID) -> Optional[Set[str]]:
|
||||
"""从记忆配置获取关联的本体类型"""
|
||||
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
|
||||
if not memory_config or not memory_config.scene_id:
|
||||
return None
|
||||
return get_ontology_types_from_scene(db, memory_config.scene_id)
|
||||
|
||||
|
||||
def get_all_ontology_types(db) -> Dict[str, Set[str]]:
|
||||
"""获取所有工作空间的本体类型"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
scenes = db.query(OntologyScene).all()
|
||||
all_types = {}
|
||||
|
||||
for scene in scenes:
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_classes = class_repo.get_by_scene(scene.scene_id)
|
||||
for oc in ontology_classes:
|
||||
if oc.class_name not in all_types:
|
||||
all_types[oc.class_name] = set()
|
||||
all_types[oc.class_name].add(scene.scene_name)
|
||||
|
||||
return all_types
|
||||
|
||||
|
||||
def get_general_ontology_types() -> Set[str]:
|
||||
"""获取通用本体类型名称集合"""
|
||||
if not is_general_ontology_enabled():
|
||||
return set()
|
||||
|
||||
try:
|
||||
registry = get_general_ontology_registry()
|
||||
return set(registry.types.keys())
|
||||
except Exception as e:
|
||||
print(f"⚠️ 加载通用本体类型失败: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
async def query_ontology_matched_entities(end_user_id: str, config_id: Optional[str] = None):
|
||||
"""查询与本体类型匹配的实体"""
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f"查询 Neo4j 中与本体类型匹配的实体")
|
||||
print(f"{'='*70}")
|
||||
print(f"end_user_id: {end_user_id}")
|
||||
|
||||
db = SessionLocal()
|
||||
connector = Neo4jConnector()
|
||||
|
||||
try:
|
||||
# 1. 获取场景本体类型集合
|
||||
scene_ontology_types: Set[str] = set()
|
||||
scene_name = "所有场景"
|
||||
|
||||
if config_id:
|
||||
try:
|
||||
config_uuid = UUID(config_id)
|
||||
types = get_ontology_types_from_config(db, config_uuid)
|
||||
if types:
|
||||
scene_ontology_types = types
|
||||
memory_config = MemoryConfigRepository.get_by_id(db, config_uuid)
|
||||
if memory_config and memory_config.scene_id:
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(memory_config.scene_id)
|
||||
if scene:
|
||||
scene_name = scene.scene_name
|
||||
print(f"config_id: {config_id}")
|
||||
print(f"关联场景: {scene_name}")
|
||||
except ValueError:
|
||||
print(f"⚠️ 无效的 config_id 格式: {config_id}")
|
||||
|
||||
# 如果没有指定 config_id 或获取失败,获取所有场景本体类型
|
||||
if not scene_ontology_types:
|
||||
all_types = get_all_ontology_types(db)
|
||||
scene_ontology_types = set(all_types.keys())
|
||||
print(f"使用所有场景本体类型进行匹配")
|
||||
|
||||
# 2. 获取通用本体类型
|
||||
general_ontology_types = get_general_ontology_types()
|
||||
|
||||
print(f"\n📋 场景本体类型 (共 {len(scene_ontology_types)} 个):")
|
||||
print(f" {'-'*50}")
|
||||
for i, type_name in enumerate(sorted(scene_ontology_types)[:20], 1):
|
||||
print(f" {i:2}. {type_name}")
|
||||
if len(scene_ontology_types) > 20:
|
||||
print(f" ... 还有 {len(scene_ontology_types) - 20} 个")
|
||||
|
||||
print(f"\n📋 通用本体类型 (共 {len(general_ontology_types)} 个):")
|
||||
print(f" {'-'*50}")
|
||||
sample_general_types = sorted(general_ontology_types)[:20]
|
||||
for i, type_name in enumerate(sample_general_types, 1):
|
||||
print(f" {i:2}. {type_name}")
|
||||
if len(general_ontology_types) > 20:
|
||||
print(f" ... 还有 {len(general_ontology_types) - 20} 个")
|
||||
|
||||
# 3. 从 Neo4j 查询实体
|
||||
print(f"\n🔍 正在查询 Neo4j...")
|
||||
entities = await get_entities_by_end_user_id(connector, end_user_id)
|
||||
|
||||
if not entities:
|
||||
print(f"\n⚠️ 未找到 end_user_id={end_user_id} 的任何实体")
|
||||
return
|
||||
|
||||
print(f" 找到 {len(entities)} 个实体")
|
||||
|
||||
# 4. 分类实体(场景类型、通用类型、未匹配)
|
||||
scene_matched_entities = []
|
||||
general_matched_entities = []
|
||||
both_matched_entities = [] # 同时匹配场景和通用类型
|
||||
unmatched_entities = []
|
||||
|
||||
scene_type_distribution = defaultdict(list)
|
||||
general_type_distribution = defaultdict(list)
|
||||
|
||||
for entity in entities:
|
||||
entity_type = entity.get('entity_type', '')
|
||||
in_scene = entity_type in scene_ontology_types
|
||||
in_general = entity_type in general_ontology_types
|
||||
|
||||
if in_scene and in_general:
|
||||
both_matched_entities.append(entity)
|
||||
scene_type_distribution[entity_type].append(entity)
|
||||
general_type_distribution[entity_type].append(entity)
|
||||
elif in_scene:
|
||||
scene_matched_entities.append(entity)
|
||||
scene_type_distribution[entity_type].append(entity)
|
||||
elif in_general:
|
||||
general_matched_entities.append(entity)
|
||||
general_type_distribution[entity_type].append(entity)
|
||||
else:
|
||||
unmatched_entities.append(entity)
|
||||
|
||||
# 5. 输出匹配场景类型的实体
|
||||
total_scene_matched = len(scene_matched_entities) + len(both_matched_entities)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"✅ 匹配场景本体类型的实体 (共 {total_scene_matched} 个)")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if scene_type_distribution:
|
||||
for type_name in sorted(scene_type_distribution.keys()):
|
||||
entities_of_type = scene_type_distribution[type_name]
|
||||
print(f"\n📌 类型: {type_name} ({len(entities_of_type)} 个)")
|
||||
print(f" {'-'*50}")
|
||||
for entity in entities_of_type[:3]:
|
||||
name = entity.get('name', 'N/A')
|
||||
desc = entity.get('description', '')
|
||||
desc_preview = (desc[:50] + "...") if desc and len(desc) > 50 else (desc or "无描述")
|
||||
print(f" • {name}")
|
||||
print(f" 描述: {desc_preview}")
|
||||
if len(entities_of_type) > 3:
|
||||
print(f" ... 还有 {len(entities_of_type) - 3} 个")
|
||||
else:
|
||||
print(f"\n (无匹配场景类型的实体)")
|
||||
|
||||
# 6. 输出匹配通用类型的实体
|
||||
total_general_matched = len(general_matched_entities) + len(both_matched_entities)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"✅ 匹配通用本体类型的实体 (共 {total_general_matched} 个)")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if general_type_distribution:
|
||||
for type_name in sorted(general_type_distribution.keys()):
|
||||
entities_of_type = general_type_distribution[type_name]
|
||||
print(f"\n📌 类型: {type_name} ({len(entities_of_type)} 个)")
|
||||
print(f" {'-'*50}")
|
||||
for entity in entities_of_type[:3]:
|
||||
name = entity.get('name', 'N/A')
|
||||
desc = entity.get('description', '')
|
||||
desc_preview = (desc[:50] + "...") if desc and len(desc) > 50 else (desc or "无描述")
|
||||
print(f" • {name}")
|
||||
print(f" 描述: {desc_preview}")
|
||||
if len(entities_of_type) > 3:
|
||||
print(f" ... 还有 {len(entities_of_type) - 3} 个")
|
||||
else:
|
||||
print(f"\n (无匹配通用类型的实体)")
|
||||
|
||||
# 7. 输出未匹配的实体
|
||||
print(f"\n{'='*70}")
|
||||
print(f"❌ 未匹配任何本体类型的实体 (共 {len(unmatched_entities)} 个)")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if unmatched_entities:
|
||||
unmatched_by_type = defaultdict(list)
|
||||
for entity in unmatched_entities:
|
||||
entity_type = entity.get('entity_type', 'Unknown')
|
||||
unmatched_by_type[entity_type].append(entity)
|
||||
|
||||
for type_name in sorted(unmatched_by_type.keys()):
|
||||
entities_of_type = unmatched_by_type[type_name]
|
||||
print(f"\n📌 类型: {type_name} ({len(entities_of_type)} 个)")
|
||||
print(f" {'-'*50}")
|
||||
for entity in entities_of_type[:3]:
|
||||
name = entity.get('name', 'N/A')
|
||||
print(f" • {name}")
|
||||
if len(entities_of_type) > 3:
|
||||
print(f" ... 还有 {len(entities_of_type) - 3} 个")
|
||||
else:
|
||||
print(f"\n (所有实体都匹配本体类型)")
|
||||
|
||||
# 8. 统计摘要
|
||||
total_entities = len(entities)
|
||||
any_matched = total_entities - len(unmatched_entities)
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f"📊 统计摘要")
|
||||
print(f"{'='*70}")
|
||||
print(f"\n 基础统计:")
|
||||
print(f" {'-'*50}")
|
||||
print(f" 总实体数: {total_entities}")
|
||||
print(f" 场景本体类型数: {len(scene_ontology_types)}")
|
||||
print(f" 通用本体类型数: {len(general_ontology_types)}")
|
||||
|
||||
print(f"\n 匹配率统计:")
|
||||
print(f" {'-'*50}")
|
||||
scene_rate = total_scene_matched / total_entities * 100 if total_entities > 0 else 0
|
||||
general_rate = total_general_matched / total_entities * 100 if total_entities > 0 else 0
|
||||
any_rate = any_matched / total_entities * 100 if total_entities > 0 else 0
|
||||
unmatched_rate = len(unmatched_entities) / total_entities * 100 if total_entities > 0 else 0
|
||||
|
||||
print(f" 匹配场景类型: {total_scene_matched} 个 ({scene_rate:.1f}%)")
|
||||
print(f" 匹配通用类型: {total_general_matched} 个 ({general_rate:.1f}%)")
|
||||
print(f" 同时匹配两者: {len(both_matched_entities)} 个 ({len(both_matched_entities)/total_entities*100:.1f}%)")
|
||||
print(f" 仅匹配场景类型: {len(scene_matched_entities)} 个 ({len(scene_matched_entities)/total_entities*100:.1f}%)")
|
||||
print(f" 仅匹配通用类型: {len(general_matched_entities)} 个 ({len(general_matched_entities)/total_entities*100:.1f}%)")
|
||||
print(f" 匹配任一类型: {any_matched} 个 ({any_rate:.1f}%)")
|
||||
print(f" 未匹配任何类型: {len(unmatched_entities)} 个 ({unmatched_rate:.1f}%)")
|
||||
|
||||
# 9. 类型分布详情
|
||||
if scene_type_distribution:
|
||||
print(f"\n 场景类型分布 (Top 10):")
|
||||
print(f" {'-'*50}")
|
||||
sorted_scene_types = sorted(scene_type_distribution.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
for type_name, entities_list in sorted_scene_types[:10]:
|
||||
print(f" - {type_name}: {len(entities_list)} 个")
|
||||
|
||||
if general_type_distribution:
|
||||
print(f"\n 通用类型分布 (Top 10):")
|
||||
print(f" {'-'*50}")
|
||||
sorted_general_types = sorted(general_type_distribution.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
for type_name, entities_list in sorted_general_types[:10]:
|
||||
print(f" - {type_name}: {len(entities_list)} 个")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 查询出错: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db.close()
|
||||
await connector.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("用法: python scripts/query_ontology_matched_entities.py <end_user_id> [config_id]")
|
||||
print("示例: python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1")
|
||||
print(" python scripts/query_ontology_matched_entities.py 075660cf-08e6-40a6-a76e-308b6f52fbf1 fd547bb9-7b9e-47ea-ae53-242d208a31a2")
|
||||
sys.exit(1)
|
||||
|
||||
end_user_id = sys.argv[1]
|
||||
config_id = sys.argv[2] if len(sys.argv) > 2 else None
|
||||
|
||||
asyncio.run(query_ontology_matched_entities(end_user_id, config_id))
|
||||
@@ -415,6 +415,9 @@ class MemoryConfig:
|
||||
pruning_scene: Optional[str] = "education"
|
||||
pruning_threshold: float = 0.5
|
||||
|
||||
# Ontology scene association
|
||||
scene_id: Optional[UUID] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
if not self.config_name or not self.config_name.strip():
|
||||
|
||||
@@ -330,6 +330,7 @@ class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用
|
||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||
custom_text: Optional[str] = Field(None, description="自定义输入文本,当配置关联本体场景时使用此字段进行试运行")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
|
||||
|
||||
class ExtractionRequest(BaseModel):
|
||||
@@ -74,47 +74,51 @@ class ExtractionResponse(BaseModel):
|
||||
extracted_count: int = Field(..., description="提取的类数量")
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
"""OWL文件导出请求模型
|
||||
class ExportBySceneRequest(BaseModel):
|
||||
"""按场景导出OWL文件请求模型
|
||||
|
||||
用于POST /api/ontology/export端点的请求体。
|
||||
根据scene_id从数据库查询该场景下的所有本体类型并导出为OWL文件。
|
||||
|
||||
Attributes:
|
||||
classes: 要导出的本体类列表
|
||||
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
|
||||
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
|
||||
scene_id: 本体场景ID,必填,用于查询该场景下的所有类型
|
||||
format: 导出格式,可选值:rdfxml(默认)、turtle
|
||||
|
||||
Examples:
|
||||
>>> request = ExportRequest(
|
||||
... classes=[...],
|
||||
... format="rdfxml",
|
||||
... include_metadata=True
|
||||
>>> request = ExportBySceneRequest(
|
||||
... scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
|
||||
... format="rdfxml"
|
||||
... )
|
||||
"""
|
||||
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
|
||||
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
|
||||
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
|
||||
scene_id: UUID = Field(..., description="本体场景ID")
|
||||
format: str = Field("rdfxml", description="导出格式,可选值:rdfxml(默认)、turtle")
|
||||
|
||||
|
||||
class ExportResponse(BaseModel):
|
||||
"""OWL文件导出响应模型
|
||||
class ExportBySceneResponse(BaseModel):
|
||||
"""按场景导出OWL文件响应模型
|
||||
|
||||
用于POST /api/ontology/export端点的响应体。
|
||||
|
||||
Attributes:
|
||||
owl_content: OWL文件内容
|
||||
format: 导出格式
|
||||
filename: 导出文件名(含扩展名)
|
||||
scene_id: 场景ID
|
||||
scene_name: 场景名称
|
||||
classes_count: 导出的类数量
|
||||
|
||||
Examples:
|
||||
>>> response = ExportResponse(
|
||||
>>> response = ExportBySceneResponse(
|
||||
... owl_content="<?xml version='1.0'?>...",
|
||||
... format="rdfxml",
|
||||
... filename="medical_ontology.owl",
|
||||
... scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
|
||||
... scene_name="医疗场景",
|
||||
... classes_count=7
|
||||
... )
|
||||
"""
|
||||
owl_content: str = Field(..., description="OWL文件内容")
|
||||
format: str = Field(..., description="导出格式")
|
||||
filename: str = Field(..., description="导出文件名(含扩展名)")
|
||||
scene_id: UUID = Field(..., description="场景ID")
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
classes_count: int = Field(..., description="导出的类数量")
|
||||
|
||||
|
||||
@@ -459,3 +463,56 @@ class ClassListResponse(BaseModel):
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||
|
||||
|
||||
# ==================== OWL 导入相关 Schema ====================
|
||||
|
||||
class ImportOwlRequest(BaseModel):
|
||||
"""OWL 文件导入请求
|
||||
|
||||
用于 POST /api/ontology/import 端点的请求体。
|
||||
解析 OWL 文件并将类型直接导入到指定场景。
|
||||
|
||||
Attributes:
|
||||
scene_id: 目标场景ID,必填
|
||||
owl_content: OWL 文件内容(字符串形式)
|
||||
format: 文件格式,可选值:rdfxml(默认)、turtle
|
||||
|
||||
Examples:
|
||||
>>> request = ImportOwlRequest(
|
||||
... scene_id=UUID("550e8400-e29b-41d4-a716-446655440000"),
|
||||
... owl_content="<?xml version='1.0'?>...",
|
||||
... format="rdfxml"
|
||||
... )
|
||||
"""
|
||||
scene_id: UUID = Field(..., description="目标场景ID")
|
||||
owl_content: str = Field(..., min_length=1, description="OWL 文件内容")
|
||||
format: str = Field("rdfxml", description="文件格式,可选值:rdfxml(默认)、turtle")
|
||||
|
||||
|
||||
class ImportOwlResponse(BaseModel):
|
||||
"""OWL 文件导入响应
|
||||
|
||||
用于返回导入结果。
|
||||
|
||||
Attributes:
|
||||
scene_id: 场景ID
|
||||
scene_name: 场景名称
|
||||
imported_count: 成功导入的类型数量
|
||||
skipped_count: 跳过的数量(重复)
|
||||
items: 导入的类型列表
|
||||
|
||||
Examples:
|
||||
>>> response = ImportOwlResponse(
|
||||
... scene_id=UUID("..."),
|
||||
... scene_name="智能制造场景",
|
||||
... imported_count=4,
|
||||
... skipped_count=0,
|
||||
... items=[...]
|
||||
... )
|
||||
"""
|
||||
scene_id: UUID = Field(..., description="场景ID")
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
imported_count: int = Field(..., description="成功导入的类型数量")
|
||||
skipped_count: int = Field(0, description="跳过的数量(重复)")
|
||||
items: List[ClassResponse] = Field(..., description="导入的类型列表")
|
||||
|
||||
@@ -303,6 +303,8 @@ class MemoryConfigService:
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
# Ontology scene association
|
||||
scene_id=memory_config.scene_id,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
@@ -476,6 +478,43 @@ class MemoryConfigService:
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
|
||||
def get_ontology_types(self, memory_config: MemoryConfig):
|
||||
"""Fetch ontology types for the memory configuration's scene.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing scene_id
|
||||
|
||||
Returns:
|
||||
OntologyTypeList if scene_id is valid and has types, None otherwise
|
||||
"""
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
if not memory_config.scene_id:
|
||||
logger.debug("No scene_id configured, skipping ontology type fetch")
|
||||
return None
|
||||
|
||||
try:
|
||||
ontology_repo = OntologyClassRepository(self.db)
|
||||
ontology_classes = ontology_repo.get_by_scene(memory_config.scene_id)
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
return None
|
||||
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
return ontology_types
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
def get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
|
||||
@@ -280,12 +280,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
config_service = MemoryConfigService(self.db)
|
||||
@@ -297,6 +291,30 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
except ConfigurationError as e:
|
||||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||||
|
||||
# 根据是否关联本体场景选择使用的文本
|
||||
# 如果配置关联了本体场景(scene_id 不为空),使用 custom_text(如果提供)
|
||||
# 否则使用 dialogue_text
|
||||
if memory_config.scene_id:
|
||||
# 关联了本体场景,优先使用 custom_text
|
||||
if hasattr(payload, 'custom_text') and payload.custom_text:
|
||||
dialogue_text = payload.custom_text.strip()
|
||||
logger.info(f"[PILOT_RUN_STREAM] Using custom_text for scene_id={memory_config.scene_id}, length: {len(dialogue_text)}")
|
||||
else:
|
||||
# 如果没有提供 custom_text,回退到 dialogue_text
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] No custom_text provided, using dialogue_text for scene_id={memory_config.scene_id}")
|
||||
else:
|
||||
# 没有关联本体场景,使用 dialogue_text
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] No scene_id, using dialogue_text, length: {len(dialogue_text)}")
|
||||
|
||||
# 验证最终使用的文本不为空
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供有效的文本内容(dialogue_text 或 custom_text)")
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Final text preview: {dialogue_text[:100]}")
|
||||
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_models import (
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
@@ -49,6 +49,10 @@ class OntologyService:
|
||||
DEFAULT_LLM_TIMEOUT = 30.0
|
||||
DEFAULT_ENABLE_OWL_VALIDATION = True
|
||||
|
||||
# 从环境变量获取默认语言
|
||||
from app.core.config import settings
|
||||
DEFAULT_LANGUAGE = settings.DEFAULT_LANGUAGE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
|
||||
@@ -142,6 +142,20 @@ async def run_pilot_extraction(
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
)
|
||||
|
||||
# 加载本体类型(如果配置了 scene_id),支持通用类型回退
|
||||
ontology_types = None
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_with_fallback
|
||||
|
||||
ontology_types = load_ontology_types_with_fallback(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db,
|
||||
enable_general_fallback=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load ontology types: {e}", exc_info=True)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
@@ -150,6 +164,7 @@ async def run_pilot_extraction(
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
108
api/app/tasks.py
108
api/app/tasks.py
@@ -1697,114 +1697,6 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
|
||||
def long_term_storage_window_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
langchain_messages: List[Dict[str, Any]],
|
||||
config_id: str,
|
||||
scope: int = 6
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task for window-based long-term memory storage.
|
||||
|
||||
Accumulates messages in Redis buffer until window size (scope) is reached,
|
||||
then writes batched messages to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: End user identifier
|
||||
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
config_id: Memory configuration ID
|
||||
scope: Window size (number of messages before triggering write)
|
||||
|
||||
Returns:
|
||||
Dict containing task status and metadata
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
# Save to Redis buffer first
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# Get workspace_id from end_user for fallback
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
workspace_id = None
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if end_user:
|
||||
app = db.query(App).filter(App.id == end_user.app_id).first()
|
||||
if app:
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# Load memory config with workspace fallback
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
service_name="LongTermStorageTask"
|
||||
)
|
||||
|
||||
# Execute window-based dialogue storage
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
|
||||
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
**result,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"strategy": "window",
|
||||
"error": str(e),
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||
# def long_term_storage_time_task(
|
||||
# self,
|
||||
|
||||
@@ -1,4 +1,32 @@
|
||||
{
|
||||
"v0.2.3": {
|
||||
"introduction": {
|
||||
"codeName": "归墟",
|
||||
"releaseDate": "2026-2-6",
|
||||
"upgradePosition": "🐻 稳定性与细节打磨版本,万流归墟,静水流深",
|
||||
"coreUpgrades": [
|
||||
"1. 智能与记忆 🧠<br>* 提示词工程模块:新增专用提示词工程能力<br>* 长短期记忆整合:增强短期与长期记忆生命周期管理<br>* 双语记忆支持:解决情景记忆、显性记忆的双语问题",
|
||||
"2. 系统架构 ⚙️<br>* 反思任务调度器:新增 worker-periodic 容器<br>* 模型配置降级:记忆管理正确降级使用空间模型",
|
||||
"3. 问题修复 🔧<br>* 工作流分享:修复多轮对话产生多个conversation<br>* 流式输出:修复chat结尾缺少end标记<br>* 实体详情:移除未知类型记忆<br>* 提示词模板路径:修复jinja2路径解析错误<br>* 知识库字段:strategy更名为retrieve_type<br>* 空间头像:优化频繁调用模型接口<br>* 记忆仪表盘:修复end_users接口无返回",
|
||||
"<br>",
|
||||
"v0.2.4 将继续完善工作流代码执行功能,并推出本体工程+记忆配置入口。",
|
||||
"记忆熊,记得更牢,用得更好。🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "Settle",
|
||||
"releaseDate": "2026-2-6",
|
||||
"upgradePosition": "🐻 Stability and refinement release — still waters run deep",
|
||||
"coreUpgrades": [
|
||||
"1. Intelligence & Memory 🧠<br>* Prompt Engineering Module: New dedicated prompt engineering capabilities<br>* Long-term & Short-term Memory Integration: Enhanced memory lifecycle management<br>* Bilingual Memory Support: Resolved dual-language issues in episodic and explicit memory",
|
||||
"2. System Architecture ⚙️<br>* Reflection Task Worker: Added worker-periodic container for scheduled tasks<br>* Model Configuration Fallback: Memory management properly falls back to workspace model",
|
||||
"3. Bug Fixes 🔧<br>* Workflow Sharing: Fixed multiple conversations created during multi-turn dialogues<br>* Streaming Output: Resolved missing end marker in chat streaming<br>* Entity Details: Removed unknown type memories from All view<br>* Prompt Template Paths: Fixed jinja2 path resolution errors<br>* Knowledge Base Schema: Renamed strategy to retrieve_type<br>* Workspace Avatar: Optimized frequent model API calls<br>* Memory Dashboard: Fixed end_users endpoint empty responses",
|
||||
"<br>",
|
||||
"v0.2.4 will continue with workflow code execution enhancements and the ontology engineering + memory configuration portal.",
|
||||
"MemoryBear — remember better, work smarter. 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.2": {
|
||||
"introduction": {
|
||||
"codeName": "淬锋(Temper)",
|
||||
|
||||
@@ -129,3 +129,9 @@ KB_image2text_id=
|
||||
config_id=
|
||||
reranker_id=
|
||||
|
||||
# 本体类型融合配置 (记得写入env_example)
|
||||
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
|
||||
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
|
||||
ONTOLOGY_EXPERIMENT_MODE=true # 是否允许通过 API 动态切换本体配置
|
||||
@@ -141,6 +141,7 @@ dependencies = [
|
||||
"flower>=2.0.1",
|
||||
"aiofiles>=23.0.0",
|
||||
"owlready2>=0.46",
|
||||
"rdflib>=7.0.0",
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
]
|
||||
|
||||
@@ -39,6 +39,7 @@ python-multipart>=0.0.20
|
||||
pyyaml==6.0.3
|
||||
redis==6.4.0
|
||||
rsa==4.9.1
|
||||
rdflib>=6.0.0
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
sqlalchemy==2.0.44
|
||||
|
||||
4
api/tests/workflow/__init__.py
Normal file
4
api/tests/workflow/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/5 15:36
|
||||
4
api/tests/workflow/executor/__init__.py
Normal file
4
api/tests/workflow/executor/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6 14:45
|
||||
622
api/tests/workflow/executor/test_vairable_pool.py
Normal file
622
api/tests/workflow/executor/test_vairable_pool.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool, VariableSelector
|
||||
|
||||
|
||||
# ==================== VariableSelector 测试 ====================
|
||||
def test_variable_selector_from_string():
|
||||
"""测试从字符串创建变量选择器"""
|
||||
selector = VariableSelector.from_string("sys.message")
|
||||
|
||||
assert selector.namespace == "sys"
|
||||
assert selector.key == "message"
|
||||
assert selector.path == ["sys", "message"]
|
||||
|
||||
|
||||
def test_variable_selector_from_list():
|
||||
"""测试从列表创建变量选择器"""
|
||||
selector = VariableSelector(["conv", "username"])
|
||||
|
||||
assert selector.namespace == "conv"
|
||||
assert selector.key == "username"
|
||||
assert str(selector) == "conv.username"
|
||||
|
||||
|
||||
def test_variable_selector_empty_path():
|
||||
"""测试空路径抛出异常"""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
VariableSelector([])
|
||||
|
||||
assert "变量路径不能为空" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_variable_selector_single_element():
|
||||
"""测试单元素路径"""
|
||||
selector = VariableSelector(["sys"])
|
||||
|
||||
assert selector.namespace == "sys"
|
||||
assert selector.key is None
|
||||
|
||||
|
||||
# ==================== VariablePool 基础测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_new_variable():
|
||||
"""测试创建新变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)
|
||||
|
||||
assert pool.has("conv.username")
|
||||
assert pool.get_value("conv.username") == "Alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_new_multiple_variables():
|
||||
"""测试创建多个变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "name", "Bob", VariableType.STRING, mut=True)
|
||||
await pool.new("conv", "age", 25, VariableType.NUMBER, mut=True)
|
||||
await pool.new("conv", "active", True, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
assert pool.get_value("conv.name") == "Bob"
|
||||
assert pool.get_value("conv.age") == 25
|
||||
assert pool.get_value("conv.active") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_different_namespaces():
|
||||
"""测试不同命名空间的变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "message", "World", VariableType.STRING, mut=True)
|
||||
await pool.new("node1", "output", "Result", VariableType.STRING, mut=False)
|
||||
|
||||
assert pool.get_value("sys.message") == "Hello"
|
||||
assert pool.get_value("conv.message") == "World"
|
||||
assert pool.get_value("node1.output") == "Result"
|
||||
|
||||
|
||||
# ==================== get_value 测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_value_with_template():
|
||||
"""测试使用模板语法获取值"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "test", "value", VariableType.STRING, mut=True)
|
||||
|
||||
# 支持模板语法
|
||||
assert pool.get_value("{{ conv.test }}") == "value"
|
||||
assert pool.get_value("{{conv.test}}") == "value"
|
||||
assert pool.get_value("{{ conv.test}}") == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_value_not_exist_strict():
|
||||
"""测试获取不存在的变量(严格模式)"""
|
||||
pool = VariablePool()
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
pool.get_value("conv.nonexistent")
|
||||
|
||||
assert "not exist" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_value_not_exist_with_default():
|
||||
"""测试获取不存在的变量(使用默认值)"""
|
||||
pool = VariablePool()
|
||||
|
||||
result = pool.get_value("conv.nonexistent", default="default_value", strict=False)
|
||||
|
||||
assert result == "default_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_value_different_types():
|
||||
"""测试获取不同类型的变量值"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "str", "text", VariableType.STRING, mut=True)
|
||||
await pool.new("conv", "num", 42, VariableType.NUMBER, mut=True)
|
||||
await pool.new("conv", "bool", False, VariableType.BOOLEAN, mut=True)
|
||||
await pool.new("conv", "arr", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await pool.new("conv", "obj", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
|
||||
assert pool.get_value("conv.str") == "text"
|
||||
assert pool.get_value("conv.num") == 42
|
||||
assert pool.get_value("conv.bool") is False
|
||||
assert pool.get_value("conv.arr") == [1, 2, 3]
|
||||
assert pool.get_value("conv.obj") == {"key": "value"}
|
||||
|
||||
|
||||
# ==================== set 测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mutable_variable():
|
||||
"""测试设置可变变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True)
|
||||
await pool.set("conv.counter", 10)
|
||||
|
||||
assert pool.get_value("conv.counter") == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_immutable_variable():
|
||||
"""测试设置不可变变量(应该失败)"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "original", VariableType.STRING, mut=False)
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
await pool.set("sys.message", "modified")
|
||||
|
||||
assert "cannot be modified" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_nonexistent_variable():
|
||||
"""测试设置不存在的变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
await pool.set("conv.nonexistent", "value")
|
||||
|
||||
assert "is not defined" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_multiple_times():
|
||||
"""测试多次设置变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "value", "first", VariableType.STRING, mut=True)
|
||||
await pool.set("conv.value", "second")
|
||||
await pool.set("conv.value", "third")
|
||||
|
||||
assert pool.get_value("conv.value") == "third"
|
||||
|
||||
|
||||
# ==================== has 测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_existing_variable():
|
||||
"""测试检查存在的变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "test", "value", VariableType.STRING, mut=True)
|
||||
|
||||
assert pool.has("conv.test") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_nonexistent_variable():
|
||||
"""测试检查不存在的变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
assert pool.has("conv.nonexistent") is False
|
||||
|
||||
|
||||
# ==================== get_literal 测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_literal():
|
||||
"""测试获取变量的字面量表示"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "num", 42, VariableType.NUMBER, mut=True)
|
||||
|
||||
literal = pool.get_literal("conv.num")
|
||||
|
||||
assert isinstance(literal, str)
|
||||
|
||||
|
||||
# ==================== 命名空间操作测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_system_vars():
|
||||
"""测试获取所有系统变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
await pool.new("sys", "user_id", "user123", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "other", "value", VariableType.STRING, mut=True)
|
||||
|
||||
sys_vars = pool.get_all_system_vars()
|
||||
|
||||
assert "message" in sys_vars
|
||||
assert "user_id" in sys_vars
|
||||
assert "other" not in sys_vars
|
||||
assert sys_vars["message"] == "Hello"
|
||||
assert sys_vars["user_id"] == "user123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_conversation_vars():
|
||||
"""测试获取所有会话变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)
|
||||
await pool.new("conv", "score", 100, VariableType.NUMBER, mut=True)
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
|
||||
conv_vars = pool.get_all_conversation_vars()
|
||||
|
||||
assert "username" in conv_vars
|
||||
assert "score" in conv_vars
|
||||
assert "message" not in conv_vars
|
||||
assert conv_vars["username"] == "Alice"
|
||||
assert conv_vars["score"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_node_outputs():
|
||||
"""测试获取所有节点输出"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("node1", "output", "result1", VariableType.STRING, mut=False)
|
||||
await pool.new("node2", "output", "result2", VariableType.STRING, mut=False)
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "var", "value", VariableType.STRING, mut=True)
|
||||
|
||||
node_outputs = pool.get_all_node_outputs()
|
||||
|
||||
assert "node1" in node_outputs
|
||||
assert "node2" in node_outputs
|
||||
assert "sys" not in node_outputs
|
||||
assert "conv" not in node_outputs
|
||||
assert node_outputs["node1"]["output"] == "result1"
|
||||
assert node_outputs["node2"]["output"] == "result2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_node_output():
|
||||
"""测试获取指定节点的输出"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("node1", "output", "result", VariableType.STRING, mut=False)
|
||||
await pool.new("node1", "status", "success", VariableType.STRING, mut=False)
|
||||
|
||||
node_output = pool.get_node_output("node1")
|
||||
|
||||
assert node_output["output"] == "result"
|
||||
assert node_output["status"] == "success"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_node_output_not_exist_strict():
|
||||
"""测试获取不存在的节点输出(严格模式)"""
|
||||
pool = VariablePool()
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
pool.get_node_output("nonexistent_node")
|
||||
|
||||
assert "output not exist" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_node_output_not_exist_with_default():
|
||||
"""测试获取不存在的节点输出(使用默认值)"""
|
||||
pool = VariablePool()
|
||||
|
||||
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ==================== 复杂场景测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_new_existing_mutable():
|
||||
"""测试创建已存在的可变变量(应该更新值)"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True)
|
||||
await pool.new("conv", "counter", 10, VariableType.NUMBER, mut=True)
|
||||
|
||||
assert pool.get_value("conv.counter") == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_new_existing_immutable():
|
||||
"""测试创建已存在的不可变变量(应该为新值)"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "original", VariableType.STRING, mut=False)
|
||||
await pool.new("sys", "message", "modified", VariableType.STRING, mut=False)
|
||||
|
||||
# 不可变变量被更新
|
||||
assert pool.get_value("sys.message") == "modified"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_zero_and_false_values():
|
||||
"""测试零值和 False 值"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "zero", 0, VariableType.NUMBER, mut=True)
|
||||
await pool.new("conv", "false", False, VariableType.BOOLEAN, mut=True)
|
||||
await pool.new("conv", "empty_str", "", VariableType.STRING, mut=True)
|
||||
await pool.new("conv", "empty_arr", [], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await pool.new("conv", "empty_obj", {}, VariableType.OBJECT, mut=True)
|
||||
|
||||
assert pool.get_value("conv.zero") == 0
|
||||
assert pool.get_value("conv.false") is False
|
||||
assert pool.get_value("conv.empty_str") == ""
|
||||
assert pool.get_value("conv.empty_arr") == []
|
||||
assert pool.get_value("conv.empty_obj") == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_nested_objects():
|
||||
"""测试嵌套对象"""
|
||||
pool = VariablePool()
|
||||
|
||||
nested_obj = {
|
||||
"user": {
|
||||
"name": "Alice",
|
||||
"age": 25,
|
||||
"address": {
|
||||
"city": "Beijing"
|
||||
}
|
||||
},
|
||||
"items": [1, 2, 3]
|
||||
}
|
||||
|
||||
await pool.new("conv", "data", nested_obj, VariableType.OBJECT, mut=True)
|
||||
|
||||
result = pool.get_value("conv.data")
|
||||
assert result["user"]["name"] == "Alice"
|
||||
assert result["user"]["address"]["city"] == "Beijing"
|
||||
assert result["items"] == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_array_of_objects():
|
||||
"""测试对象数组"""
|
||||
pool = VariablePool()
|
||||
|
||||
users = [
|
||||
{"name": "Alice", "age": 25},
|
||||
{"name": "Bob", "age": 30}
|
||||
]
|
||||
|
||||
await pool.new("conv", "users", users, VariableType.ARRAY_OBJECT, mut=True)
|
||||
|
||||
result = pool.get_value("conv.users")
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "Alice"
|
||||
assert result[1]["age"] == 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_to_dict():
|
||||
"""测试导出为字典"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)
|
||||
await pool.new("node1", "output", "result", VariableType.STRING, mut=False)
|
||||
|
||||
result = pool.to_dict()
|
||||
|
||||
assert "system" in result
|
||||
assert "conversation" in result
|
||||
assert "nodes" in result
|
||||
assert result["system"]["message"] == "Hello"
|
||||
assert result["conversation"]["username"] == "Alice"
|
||||
assert result["nodes"]["node1"]["output"] == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_copy():
|
||||
"""测试复制变量池"""
|
||||
pool1 = VariablePool()
|
||||
|
||||
await pool1.new("conv", "test", "value", VariableType.STRING, mut=True)
|
||||
|
||||
pool2 = VariablePool()
|
||||
pool2.copy(pool1)
|
||||
|
||||
assert pool2.get_value("conv.test") == "value"
|
||||
|
||||
# 修改 pool2 不应影响 pool1
|
||||
await pool2.set("conv.test", "modified")
|
||||
assert pool2.get_value("conv.test") == "modified"
|
||||
assert pool1.get_value("conv.test") == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_repr():
|
||||
"""测试字符串表示"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "message", "Hello", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)
|
||||
await pool.new("node1", "output", "result", VariableType.STRING, mut=False)
|
||||
|
||||
repr_str = repr(pool)
|
||||
|
||||
assert "VariablePool" in repr_str
|
||||
assert "system_vars=1" in repr_str
|
||||
assert "conversation_vars=1" in repr_str
|
||||
assert "runtime_vars=1" in repr_str
|
||||
|
||||
|
||||
# ==================== 并发测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_concurrent_set():
|
||||
"""测试并发设置变量"""
|
||||
import asyncio
|
||||
|
||||
pool = VariablePool()
|
||||
await pool.new("conv", "counter", 0, VariableType.NUMBER, mut=True)
|
||||
|
||||
async def increment():
|
||||
for _ in range(100):
|
||||
current = pool.get_value("conv.counter")
|
||||
await pool.set("conv.counter", current + 1)
|
||||
|
||||
# 并发执行多个增量操作
|
||||
await asyncio.gather(increment(), increment())
|
||||
|
||||
# 由于有锁保护,最终值应该是 200
|
||||
assert pool.get_value("conv.counter") == 200
|
||||
|
||||
|
||||
# ==================== 边界情况测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_empty():
|
||||
"""测试空变量池"""
|
||||
pool = VariablePool()
|
||||
|
||||
assert pool.get_all_system_vars() == {}
|
||||
assert pool.get_all_conversation_vars() == {}
|
||||
assert pool.get_all_node_outputs() == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_selector_invalid():
|
||||
"""测试无效的变量选择器"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "test", "value", VariableType.STRING, mut=True)
|
||||
|
||||
# 选择器格式错误
|
||||
with pytest.raises(ValueError):
|
||||
pool.get_value("conv.test.extra")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_special_characters():
|
||||
"""测试包含特殊字符的变量名"""
|
||||
pool = VariablePool()
|
||||
|
||||
# 变量名可以包含下划线、数字等
|
||||
await pool.new("conv", "user_name_123", "Alice", VariableType.STRING, mut=True)
|
||||
await pool.new("node_1", "output_data", "result", VariableType.STRING, mut=False)
|
||||
|
||||
assert pool.get_value("conv.user_name_123") == "Alice"
|
||||
assert pool.get_value("node_1.output_data") == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_large_data():
|
||||
"""测试大数据量"""
|
||||
pool = VariablePool()
|
||||
|
||||
# 创建大量变量
|
||||
for i in range(100):
|
||||
await pool.new("conv", f"var_{i}", i, VariableType.NUMBER, mut=True)
|
||||
|
||||
# 验证所有变量都存在
|
||||
for i in range(100):
|
||||
assert pool.get_value(f"conv.var_{i}") == i
|
||||
|
||||
conv_vars = pool.get_all_conversation_vars()
|
||||
assert len(conv_vars) == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_different_types_same_name():
|
||||
"""测试不同命名空间中相同名称的变量"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("sys", "value", "system", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "value", "conversation", VariableType.STRING, mut=True)
|
||||
await pool.new("node1", "value", "node", VariableType.STRING, mut=False)
|
||||
|
||||
assert pool.get_value("sys.value") == "system"
|
||||
assert pool.get_value("conv.value") == "conversation"
|
||||
assert pool.get_value("node1.value") == "node"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_update_type():
|
||||
"""测试更新变量类型"""
|
||||
pool = VariablePool()
|
||||
|
||||
# 创建字符串变量
|
||||
await pool.new("conv", "data", "text", VariableType.STRING, mut=True)
|
||||
assert pool.get_value("conv.data") == "text"
|
||||
|
||||
# 更新为数字类型变量类型不可变
|
||||
with pytest.raises(TypeError):
|
||||
await pool.new("conv", "data", 123, VariableType.NUMBER, mut=True)
|
||||
assert pool.get_value("conv.data") == "text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_array_types():
|
||||
"""测试不同类型的数组"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "arr_str", ["a", "b", "c"], VariableType.ARRAY_STRING, mut=True)
|
||||
await pool.new("conv", "arr_num", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await pool.new("conv", "arr_bool", [True, False], VariableType.ARRAY_BOOLEAN, mut=True)
|
||||
await pool.new("conv", "arr_obj", [{"id": 1}, {"id": 2}], VariableType.ARRAY_OBJECT, mut=True)
|
||||
|
||||
assert pool.get_value("conv.arr_str") == ["a", "b", "c"]
|
||||
assert pool.get_value("conv.arr_num") == [1, 2, 3]
|
||||
assert pool.get_value("conv.arr_bool") == [True, False]
|
||||
assert pool.get_value("conv.arr_obj") == [{"id": 1}, {"id": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_namespace_isolation():
|
||||
"""测试命名空间隔离"""
|
||||
pool = VariablePool()
|
||||
|
||||
# 在不同命名空间创建变量
|
||||
await pool.new("sys", "var1", "sys_value", VariableType.STRING, mut=False)
|
||||
await pool.new("conv", "var2", "conv_value", VariableType.STRING, mut=True)
|
||||
await pool.new("node1", "var3", "node_value", VariableType.STRING, mut=False)
|
||||
|
||||
# 获取各命名空间的变量
|
||||
sys_vars = pool.get_all_system_vars()
|
||||
conv_vars = pool.get_all_conversation_vars()
|
||||
node_outputs = pool.get_all_node_outputs()
|
||||
|
||||
# 验证隔离性
|
||||
assert "var1" in sys_vars and "var2" not in sys_vars and "var3" not in sys_vars
|
||||
assert "var2" in conv_vars and "var1" not in conv_vars and "var3" not in conv_vars
|
||||
assert "node1" in node_outputs and "var3" in node_outputs["node1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_mutability_rules():
|
||||
"""测试可变性规则"""
|
||||
pool = VariablePool()
|
||||
|
||||
# 系统变量应该是不可变的
|
||||
await pool.new("sys", "immutable", "value", VariableType.STRING, mut=False)
|
||||
with pytest.raises(KeyError):
|
||||
await pool.set("sys.immutable", "new_value")
|
||||
|
||||
# 会话变量应该是可变的
|
||||
await pool.new("conv", "mutable", "value", VariableType.STRING, mut=True)
|
||||
await pool.set("conv.mutable", "new_value")
|
||||
assert pool.get_value("conv.mutable") == "new_value"
|
||||
|
||||
# 节点输出应该是不可变的
|
||||
await pool.new("node1", "output", "value", VariableType.STRING, mut=False)
|
||||
with pytest.raises(KeyError):
|
||||
await pool.set("node1.output", "new_value")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_pool_template_variations():
|
||||
"""测试模板语法的各种变体"""
|
||||
pool = VariablePool()
|
||||
|
||||
await pool.new("conv", "test", "value", VariableType.STRING, mut=True)
|
||||
|
||||
# 各种模板格式都应该工作
|
||||
assert pool.get_value("{{conv.test}}") == "value"
|
||||
assert pool.get_value("{{ conv.test }}") == "value"
|
||||
assert pool.get_value("{{ conv.test }}") == "value"
|
||||
assert pool.get_value("{{ conv.test}}") == "value"
|
||||
assert pool.get_value("{{conv.test }}") == "value"
|
||||
4
api/tests/workflow/nodes/__init__.py
Normal file
4
api/tests/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6 14:43
|
||||
77
api/tests/workflow/nodes/base.py
Normal file
77
api/tests/workflow/nodes/base.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/5 18:19
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
TEST_WORKSPACE_ID = "test_workspace_id"
|
||||
TEST_USER_ID = "test_user_id"
|
||||
TEST_EXECUTION_ID = "test_execution_id"
|
||||
TEST_CONVERSATION_ID = "test_conversation_id"
|
||||
TEST_MODEL_ID = "" or os.getenv("TEST_MODEL_ID")
|
||||
TEST_FILE = {
|
||||
"type": "image",
|
||||
"url": "https://inews.gtimg.com/om_bt/Ojy0PdDIWWXRTAMh2QjsiumDZh-D1x7qCkDSmoaaX6INAAA/641",
|
||||
"__file": True
|
||||
}
|
||||
INPUT_DATA = {
|
||||
"message": "",
|
||||
"variables": [],
|
||||
"conversation_id": TEST_CONVERSATION_ID,
|
||||
"files": [TEST_FILE]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def global_precheck():
|
||||
assert bool(TEST_MODEL_ID) is True, 'PLASE SET TEST_MODEL_ID FIRST'
|
||||
|
||||
|
||||
def simple_state():
|
||||
return {
|
||||
"messages": [{"role": "user", "content": "123456"}],
|
||||
"node_outputs": {},
|
||||
"execution_id": TEST_EXECUTION_ID,
|
||||
"workspace_id": TEST_WORKSPACE_ID,
|
||||
"user_id": TEST_USER_ID,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"cycle_nodes": [], # loop, iteration node id
|
||||
"looping": 0, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {}
|
||||
}
|
||||
|
||||
|
||||
async def simple_vairable_pool(message):
|
||||
# Initialize system variables (sys namespace)
|
||||
variable_pool = VariablePool()
|
||||
user_message = message
|
||||
user_files = INPUT_DATA.get("files") or []
|
||||
|
||||
# Initialize system variables (sys namespace)
|
||||
input_variables = INPUT_DATA.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_id": (INPUT_DATA.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (TEST_EXECUTION_ID, VariableType.STRING),
|
||||
"workspace_id": (TEST_WORKSPACE_ID, VariableType.STRING),
|
||||
"user_id": (TEST_USER_ID, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=VariableType(var_type),
|
||||
mut=False
|
||||
)
|
||||
return variable_pool
|
||||
834
api/tests/workflow/nodes/test_assigner_node.py
Normal file
834
api/tests/workflow/nodes/test_assigner_node.py
Normal file
@@ -0,0 +1,834 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/5 18:54
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import AssignerNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_add():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "add",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_subtract():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "subtract",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == -2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_multiply():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 2, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "multiply",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_divide():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 6, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "divide",
|
||||
"value": 2
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_assign():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "test1", 4, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "assign",
|
||||
"value": "{{conv.test1}}"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_cover():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "cover",
|
||||
"value": 4
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_clear():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "clear",
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_append():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(AttributeError) as exc_info:
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "append",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert "'NumberOperator' object has no attribute 'append'" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_remove_last():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(AttributeError) as exc_info:
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "remove_last"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert "'NumberOperator' object has no attribute 'remove_last'" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_number_remove_first():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 1, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(AttributeError) as exc_info:
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "remove_first"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert "'NumberOperator' object has no attribute 'remove_first'" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_append():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "append",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_remove_last():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "remove_last"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_remove_first():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "remove_first"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [2]
|
||||
|
||||
|
||||
# String tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_string_assign():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "assign",
|
||||
"value": "world"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == "world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_string_cover():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "cover",
|
||||
"value": "world"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == "world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_string_clear():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "clear"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_string_invalid_operation():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", "hello", VariableType.STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "add",
|
||||
"value": "world"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with pytest.raises(AttributeError) as exc_info:
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert "'StringOperator' object has no attribute 'add'" in str(exc_info.value)
|
||||
|
||||
|
||||
# Boolean tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_boolean_assign():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "assign",
|
||||
"value": False
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_boolean_cover():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", False, VariableType.BOOLEAN, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "cover",
|
||||
"value": True
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_boolean_clear():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", True, VariableType.BOOLEAN, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "clear"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") is False
|
||||
|
||||
|
||||
# Object tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_object_assign():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "assign",
|
||||
"value": {"new_key": "new_value"}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == {"new_key": "new_value"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_object_cover():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "cover",
|
||||
"value": {"new_key": "new_value"}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == {"new_key": "new_value"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_object_clear():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "clear"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == {}
|
||||
|
||||
|
||||
# Array string tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_string_append():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "append",
|
||||
"value": "c"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == ["a", "b", "c"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_string_clear():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", ["a", "b"], VariableType.ARRAY_STRING, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "clear"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_object_append():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [{"id": 1}], VariableType.ARRAY_OBJECT, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "append",
|
||||
"value": {"id": 2}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [{"id": 1}, {"id": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_assign():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "assign",
|
||||
"value": [3, 4, 5]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [3, 4, 5]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_array_cover():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "cover",
|
||||
"value": [3, 4, 5]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [3, 4, 5]
|
||||
|
||||
|
||||
# Multiple assignments test
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_multiple_assignments():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test1", 10, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "test2", "hello", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "test3", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test1}}",
|
||||
"operation": "add",
|
||||
"value": 5
|
||||
},
|
||||
{
|
||||
"variable_selector": "{{conv.test2}}",
|
||||
"operation": "assign",
|
||||
"value": "world"
|
||||
},
|
||||
{
|
||||
"variable_selector": "{{conv.test3}}",
|
||||
"operation": "append",
|
||||
"value": 3
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test1") == 15
|
||||
assert variable_pool.get_value("conv.test2") == "world"
|
||||
assert variable_pool.get_value("conv.test3") == [1, 2, 3]
|
||||
|
||||
|
||||
# Variable reference test
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_variable_reference():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "source", 100, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "target", 0, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.target}}",
|
||||
"operation": "assign",
|
||||
"value": "{{conv.source}}"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.target") == 100
|
||||
|
||||
|
||||
# Edge cases
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_divide_by_zero():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 10, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "divide",
|
||||
"value": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_invalid_namespace():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("sys", "test", 10, VariableType.NUMBER, mut=False)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{sys.test}}",
|
||||
"operation": "add",
|
||||
"value": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert "Only conversation or cycle variables can be assigned" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_empty_array_operations():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [], VariableType.ARRAY_NUMBER, mut=True)
|
||||
|
||||
# Test append on empty array
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "append",
|
||||
"value": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_remove_from_single_element_array():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", [1], VariableType.ARRAY_NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "remove_last"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigner_float_operations():
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "test", 10.5, VariableType.NUMBER, mut=True)
|
||||
config = {
|
||||
"id": "assigner_test",
|
||||
"type": "assigner",
|
||||
"name": "赋值测试节点",
|
||||
"config": {
|
||||
"assignments": [
|
||||
{
|
||||
"variable_selector": "{{conv.test}}",
|
||||
"operation": "multiply",
|
||||
"value": 2.0
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
await AssignerNode(config, {}).execute(state, variable_pool)
|
||||
assert variable_pool.get_value("conv.test") == 21.0
|
||||
23
api/tests/workflow/nodes/test_breaker_node.py
Normal file
23
api/tests/workflow/nodes/test_breaker_node.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/5 19:15
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_breaker():
|
||||
node_config = {
|
||||
"id": "breaker_test",
|
||||
"type": "breaker",
|
||||
"name": "breaker",
|
||||
"config": {
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await BreakNode(node_config, {}).execute(state, variable_pool)
|
||||
assert state["looping"] == 2
|
||||
279
api/tests/workflow/nodes/test_code.py
Normal file
279
api/tests/workflow/nodes/test_code.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6 09:59
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_python_complex_output():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjByZXR1cm4lMjAlN0IlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJudW1iZXIlMjIlM0ElMjB4JTIwJTJCJTIweSUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMnN0cmluZyUyMiUzQSUyMHN0cih4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJib29sZWFuJTIyJTNBJTIwYm9vbCh4JTIwJTJCJTIweSklMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJkaWN0JTIyJTNBJTIwJTdCJTIyc3VtJTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9zdHJpbmclMjIlM0ElMjAlNUJzdHIoeCUyMCUyQiUyMHkpJTVEJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIyYXJyYXlfbnVtYmVyJTIyJTNBJTIwJTVCeCUyMCUyQiUyMHklNUQlMkMlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjJhcnJheV9vYmplY3QlMjIlM0ElMjAlNUIlN0IlMjJzdW0lMjIlM0ElMjB4JTIwJTJCJTIweSU3RCU1RCUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMmFycmF5X2Jvb2xlYW4lMjIlM0ElMjAlNUJib29sKHglMjAlMkIlMjB5KSU1RCUwQSUyMCUyMCUyMCUyMCU3RA==",
|
||||
"language": "python3",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "number",
|
||||
"type": VariableType.NUMBER
|
||||
},
|
||||
{
|
||||
"name": "string",
|
||||
"type": VariableType.STRING
|
||||
},
|
||||
{
|
||||
"name": "boolean",
|
||||
"type": VariableType.BOOLEAN
|
||||
},
|
||||
{
|
||||
"name": "dict",
|
||||
"type": VariableType.OBJECT
|
||||
},
|
||||
{
|
||||
"name": "array_string",
|
||||
"type": VariableType.ARRAY_STRING
|
||||
},
|
||||
{
|
||||
"name": "array_number",
|
||||
"type": VariableType.ARRAY_NUMBER
|
||||
},
|
||||
{
|
||||
"name": "array_object",
|
||||
"type": VariableType.ARRAY_OBJECT
|
||||
},
|
||||
{
|
||||
"name": "array_boolean",
|
||||
"type": VariableType.ARRAY_BOOLEAN
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
result = await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
assert result == {'number': 3, 'string': '3', 'boolean': True, 'dict': {'sum': 3}, 'array_string': ['3'],
|
||||
'array_number': [3], 'array_object': [{'sum': 3}], 'array_boolean': [True]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_javascript_complex_output():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "ZnVuY3Rpb24gbWFpbih7eCwgeX0pIHsKICBjb25zdCBzdW0gPSB4ICsgeTsKCiAgcmV0dXJuIHsKICAgIG51bWJlcjogc3VtLAogICAgc3RyaW5nOiBTdHJpbmcoc3VtKSwKICAgIGJvb2xlYW46IEJvb2xlYW4oc3VtKSwKICAgIGRpY3Q6IHsgc3VtIH0sCiAgICBhcnJheV9zdHJpbmc6IFtTdHJpbmcoc3VtKV0sCiAgICBhcnJheV9udW1iZXI6IFtzdW1dLAogICAgYXJyYXlfb2JqZWN0OiBbeyBzdW0gfV0sCiAgICBhcnJheV9ib29sZWFuOiBbQm9vbGVhbihzdW0pXSwKICB9Owp9",
|
||||
"language": "javascript",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "number",
|
||||
"type": VariableType.NUMBER
|
||||
},
|
||||
{
|
||||
"name": "string",
|
||||
"type": VariableType.STRING
|
||||
},
|
||||
{
|
||||
"name": "boolean",
|
||||
"type": VariableType.BOOLEAN
|
||||
},
|
||||
{
|
||||
"name": "dict",
|
||||
"type": VariableType.OBJECT
|
||||
},
|
||||
{
|
||||
"name": "array_string",
|
||||
"type": VariableType.ARRAY_STRING
|
||||
},
|
||||
{
|
||||
"name": "array_number",
|
||||
"type": VariableType.ARRAY_NUMBER
|
||||
},
|
||||
{
|
||||
"name": "array_object",
|
||||
"type": VariableType.ARRAY_OBJECT
|
||||
},
|
||||
{
|
||||
"name": "array_boolean",
|
||||
"type": VariableType.ARRAY_BOOLEAN
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
result = await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
assert result == {'number': 3, 'string': '3', 'boolean': True, 'dict': {'sum': 3}, 'array_string': ['3'],
|
||||
'array_number': [3], 'array_object': [{'sum': 3}], 'array_boolean': [True]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_python_operation_permissions():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "ZGVmJTIwbWFpbih4JTJDJTIweSklM0ElMEElMjAlMjAlMjAlMjBpbXBvcnQlMjBvcyUwQSUyMCUyMCUyMCUyMG9zLmdldGN3ZCgpJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwJTdCJTIycmVzdWx0JTIyJTNBJTIweCUyMCUyQiUyMHklN0QlMEE=",
|
||||
"language": "python3",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "result",
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(RuntimeError, match="Operation not permitted"):
|
||||
await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_javascript_operation_permissions():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "Y29uc29sZS5sb2cocHJvY2Vzcy5nZXRldWlkKCkpOw==",
|
||||
"language": "javascript",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "result",
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(RuntimeError, match="Operation not permitted"):
|
||||
await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_python_run_error():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "ZGVmJTIwbWFpbih4JTJDJTIweSUzQSUwQSUyMCUyMCUyMCUyMHJldHVybiUyMCU3QiUyMnJlc3VsdCUyMiUzQSUyMHglMjAlMkIlMjB5JTdEJTBB",
|
||||
"language": "python3",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "result",
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
assert "'(' was never closed" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_code_javascript_run_error():
|
||||
node_config = {
|
||||
"id": "code_test",
|
||||
"type": "code",
|
||||
"name": "代码执行",
|
||||
"config": {
|
||||
"code": "Y29uc29sZS5sb2co",
|
||||
"language": "javascript",
|
||||
"input_variables": [
|
||||
{
|
||||
"name": "x",
|
||||
"variable": "{{conv.x}}"
|
||||
},
|
||||
{
|
||||
"name": "y",
|
||||
"variable": "{{conv.y}}"
|
||||
}
|
||||
],
|
||||
"output_variables": [
|
||||
{
|
||||
"name": "result",
|
||||
"type": "number"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "y", 2, VariableType.NUMBER, mut=True)
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await CodeNode(node_config, {}).execute(state, variable_pool)
|
||||
assert "SyntaxError" in str(exc_info.value)
|
||||
42
api/tests/workflow/nodes/test_end_node.py
Normal file
42
api/tests/workflow/nodes/test_end_node.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6 12:22
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import EndNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_output():
|
||||
node_config = {
|
||||
"id": "end_test",
|
||||
"type": "end",
|
||||
"name": "end",
|
||||
"config": {
|
||||
"output": "{{conv.x}}{{sys.message}}"
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "x", 1, VariableType.NUMBER, mut=True)
|
||||
result = await EndNode(node_config, {}).execute(state, variable_pool)
|
||||
assert result == "1test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_output_miss():
|
||||
node_config = {
|
||||
"id": "end_test",
|
||||
"type": "end",
|
||||
"name": "end",
|
||||
"config": {
|
||||
"output": "{{conv.x}}{{sys.message}}"
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
result = await EndNode(node_config, {}).execute(state, variable_pool)
|
||||
assert result == "test"
|
||||
1127
api/tests/workflow/nodes/test_ifelse_node.py
Normal file
1127
api/tests/workflow/nodes/test_ifelse_node.py
Normal file
File diff suppressed because it is too large
Load Diff
889
api/tests/workflow/nodes/test_jinja_render_node.py
Normal file
889
api/tests/workflow/nodes/test_jinja_render_node.py
Normal file
@@ -0,0 +1,889 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import JinjaRenderNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
# 基础模板渲染配置
|
||||
SIMPLE_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Hello, {{ name }}!",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "name",
|
||||
"value": "conv.username"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 多变量模板配置
|
||||
MULTI_VARIABLE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ greeting }}, {{ name }}! You are {{ age }} years old.",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "greeting",
|
||||
"value": "conv.greeting"
|
||||
},
|
||||
{
|
||||
"name": "name",
|
||||
"value": "conv.name"
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"value": "conv.age"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 条件渲染配置
|
||||
CONDITIONAL_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% if is_admin %}Admin{% else %}User{% endif %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "is_admin",
|
||||
"value": "conv.is_admin"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 循环渲染配置
|
||||
LOOP_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "items",
|
||||
"value": "conv.items"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 过滤器配置
|
||||
FILTER_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ text | upper }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "text",
|
||||
"value": "conv.text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 对象属性访问配置
|
||||
OBJECT_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Name: {{ user.name }}, Age: {{ user.age }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "user",
|
||||
"value": "conv.user"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 数学运算配置
|
||||
MATH_TEMPLATE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ a }} + {{ b }} = {{ a + b }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "a",
|
||||
"value": "conv.a"
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"value": "conv.b"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 默认值配置
|
||||
DEFAULT_VALUE_CONFIG = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ name | default('Guest') }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "name",
|
||||
"value": "conv.name"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 基础模板渲染测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_simple_template():
|
||||
"""测试简单模板渲染"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "username", "Alice", VariableType.STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(SIMPLE_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Hello, Alice!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_multi_variable():
|
||||
"""测试多变量模板渲染"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "greeting", "Hi", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "name", "Bob", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "age", 25, VariableType.NUMBER, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(MULTI_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Hi, Bob! You are 25 years old."
|
||||
|
||||
|
||||
# ==================== 条件渲染测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_conditional_true():
|
||||
"""测试条件渲染为真"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "is_admin", True, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Admin"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_conditional_false():
|
||||
"""测试条件渲染为假"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "is_admin", False, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(CONDITIONAL_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "User"
|
||||
|
||||
|
||||
# ==================== 循环渲染测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_loop_array():
|
||||
"""测试数组循环渲染"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "items", ["apple", "banana", "cherry"], VariableType.ARRAY_STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "apple, banana, cherry"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_loop_empty_array():
|
||||
"""测试空数组循环渲染"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "items", [], VariableType.ARRAY_STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_loop_single_item():
|
||||
"""测试单元素数组循环渲染"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "items", ["apple"], VariableType.ARRAY_STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(LOOP_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "apple"
|
||||
|
||||
|
||||
# ==================== 过滤器测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_filter_upper():
|
||||
"""测试大写过滤器"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "text", "hello world", VariableType.STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(FILTER_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "HELLO WORLD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_filter_lower():
|
||||
"""测试小写过滤器"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "text", "HELLO WORLD", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ text | lower }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "text",
|
||||
"value": "conv.text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_filter_title():
|
||||
"""测试标题化过滤器"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "text", "hello world", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ text | title }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "text",
|
||||
"value": "conv.text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Hello World"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_filter_length():
|
||||
"""测试长度过滤器"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "items", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Length: {{ items | length }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "items",
|
||||
"value": "conv.items"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Length: 5"
|
||||
|
||||
|
||||
# ==================== 对象属性访问测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_object_access():
|
||||
"""测试对象属性访问"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "user", {"name": "Alice", "age": 30}, VariableType.OBJECT, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(OBJECT_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Name: Alice, Age: 30"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_nested_object():
|
||||
"""测试嵌套对象访问"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "data", {
|
||||
"user": {
|
||||
"name": "Bob",
|
||||
"address": {
|
||||
"city": "Beijing"
|
||||
}
|
||||
}
|
||||
}, VariableType.OBJECT, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ data.user.name }} lives in {{ data.user.address.city }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "data",
|
||||
"value": "conv.data"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Bob lives in Beijing"
|
||||
|
||||
|
||||
# ==================== 数学运算测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_math_addition():
|
||||
"""测试加法运算"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "a", 10, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "b", 20, VariableType.NUMBER, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(MATH_TEMPLATE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "10 + 20 = 30"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_math_subtraction():
|
||||
"""测试减法运算"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "a", 30, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "b", 10, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ a }} - {{ b }} = {{ a - b }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "a",
|
||||
"value": "conv.a"
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"value": "conv.b"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "30 - 10 = 20"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_math_multiplication():
|
||||
"""测试乘法运算"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "a", 5, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "b", 6, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ a }} * {{ b }} = {{ a * b }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "a",
|
||||
"value": "conv.a"
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"value": "conv.b"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "5 * 6 = 30"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_math_division():
|
||||
"""测试除法运算"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "a", 20, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "b", 4, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ a }} / {{ b }} = {{ a / b }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "a",
|
||||
"value": "conv.a"
|
||||
},
|
||||
{
|
||||
"name": "b",
|
||||
"value": "conv.b"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "20 / 4 = 5.0"
|
||||
|
||||
|
||||
# ==================== 默认值测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_default_value_missing():
|
||||
"""测试变量缺失时使用默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
# 不创建 name 变量
|
||||
|
||||
result = await JinjaRenderNode(DEFAULT_VALUE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Guest"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_default_value_present():
|
||||
"""测试变量存在时不使用默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "name", "Alice", VariableType.STRING, mut=True)
|
||||
|
||||
result = await JinjaRenderNode(DEFAULT_VALUE_CONFIG, {}).execute(state, variable_pool)
|
||||
assert result == "Alice"
|
||||
|
||||
|
||||
# ==================== 字符串拼接测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_string_concatenation():
|
||||
"""测试字符串拼接"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "first", "Hello", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "second", "World", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ first ~ ' ' ~ second }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "first",
|
||||
"value": "conv.first"
|
||||
},
|
||||
{
|
||||
"name": "second",
|
||||
"value": "conv.second"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Hello World"
|
||||
|
||||
|
||||
# ==================== 比较运算测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_comparison():
|
||||
"""测试比较运算"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "score", 85, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% if score >= 90 %}A{% elif score >= 80 %}B{% elif score >= 70 %}C{% else %}D{% endif %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "score",
|
||||
"value": "conv.score"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "B"
|
||||
|
||||
|
||||
# ==================== 数组操作测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_array_index():
|
||||
"""测试数组索引访问"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "items", ["first", "second", "third"], VariableType.ARRAY_STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "First: {{ items[0] }}, Last: {{ items[-1] }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "items",
|
||||
"value": "conv.items"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "First: first, Last: third"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_array_slice():
|
||||
"""测试数组切片"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "numbers", [1, 2, 3, 4, 5], VariableType.ARRAY_NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% for n in numbers[1:4] %}{{ n }}{% endfor %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "numbers",
|
||||
"value": "conv.numbers"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "234"
|
||||
|
||||
|
||||
# ==================== 复杂模板测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_complex_template():
|
||||
"""测试复杂模板"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "users", [
|
||||
{"name": "Alice", "age": 25},
|
||||
{"name": "Bob", "age": 30},
|
||||
{"name": "Charlie", "age": 35}
|
||||
], VariableType.ARRAY_OBJECT, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% for user in users %}{{ user.name }} ({{ user.age }}){% if not loop.last %}, {% endif %}{% endfor %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "users",
|
||||
"value": "conv.users"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Alice (25), Bob (30), Charlie (35)"
|
||||
|
||||
|
||||
# ==================== 空值处理测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_empty_string():
|
||||
"""测试空字符串"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "text", "", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{% if text %}{{ text }}{% else %}Empty{% endif %}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "text",
|
||||
"value": "conv.text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_zero_value():
|
||||
"""测试零值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "count", 0, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Count: {{ count }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "count",
|
||||
"value": "conv.count"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Count: 0"
|
||||
|
||||
|
||||
# ==================== 特殊字符测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_special_characters():
|
||||
"""测试特殊字符"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "text", "Hello \"World\"", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ text }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "text",
|
||||
"value": "conv.text"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Hello \"World\""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_newline():
|
||||
"""测试换行符"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "line1", "First line", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "line2", "Second line", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ line1 }}\n{{ line2 }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "line1",
|
||||
"value": "conv.line1"
|
||||
},
|
||||
{
|
||||
"name": "line2",
|
||||
"value": "conv.line2"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "First line\nSecond line"
|
||||
|
||||
|
||||
# ==================== 错误处理测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_invalid_template():
|
||||
"""测试无效模板语法"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "name", "Alice", VariableType.STRING, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ name", # 缺少闭合括号
|
||||
"mapping": [
|
||||
{
|
||||
"name": "name",
|
||||
"value": "conv.name"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert "render failed" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_undefined_variable_strict_false():
|
||||
"""测试未定义变量(非严格模式)"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
# 不创建任何变量
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Hello, {{ undefined_var }}!",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "undefined_var",
|
||||
"value": "conv.undefined"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
# 非严格模式下,未定义变量会被渲染为空字符串
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Hello, !"
|
||||
|
||||
|
||||
# ==================== 布尔值测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_boolean_true():
|
||||
"""测试布尔值 True"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "flag", True, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Flag is {{ flag }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "flag",
|
||||
"value": "conv.flag"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Flag is True"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_boolean_false():
|
||||
"""测试布尔值 False"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Flag is {{ flag }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "flag",
|
||||
"value": "conv.flag"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Flag is False"
|
||||
|
||||
|
||||
# ==================== 浮点数测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_float_number():
|
||||
"""测试浮点数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "price", 19.99, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "Price: ${{ price }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "price",
|
||||
"value": "conv.price"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "Price: $19.99"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jinja_float_formatting():
|
||||
"""测试浮点数格式化"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "value", 3.14159, VariableType.NUMBER, mut=True)
|
||||
|
||||
config = {
|
||||
"id": "jinja_test",
|
||||
"type": "jinja-render",
|
||||
"name": "Jinja渲染测试节点",
|
||||
"config": {
|
||||
"template": "{{ '%.2f' | format(value) }}",
|
||||
"mapping": [
|
||||
{
|
||||
"name": "value",
|
||||
"value": "conv.value"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = await JinjaRenderNode(config, {}).execute(state, variable_pool)
|
||||
assert result == "3.14"
|
||||
145
api/tests/workflow/nodes/test_llm_node.py
Normal file
145
api/tests/workflow/nodes/test_llm_node.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/5 15:39
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import LLMNode
|
||||
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_memory_no_stream():
|
||||
node_config = {
|
||||
"id": "llm_test",
|
||||
"type": "llm",
|
||||
"name": "LLM 问答",
|
||||
"config": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业、友好且乐于助人的 AI 助手。"
|
||||
"你的职责:- "
|
||||
"准确理解用户的问题并提供有价值的回答"
|
||||
"- 保持回答的专业性和准确性"
|
||||
"- 如果不确定答案,诚实地告知用户"
|
||||
"- 使用清晰、易懂的语言进行交流"
|
||||
"回答风格:"
|
||||
"- 简洁明了,直击要点"
|
||||
"- 必要时提供详细解释和示例"
|
||||
"- 使用友好、礼貌的语气"
|
||||
"- 适当使用格式化(如列表、段落)提高可读性"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{ sys.message }}"
|
||||
}
|
||||
],
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"memory": {
|
||||
"enable": True,
|
||||
"enable_window": True,
|
||||
"window_size": 5
|
||||
},
|
||||
"vision": False,
|
||||
"vision_input": "{{sys.files}}"
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("输出上一句话")
|
||||
result = await LLMNode(node_config, {}).execute(state, variable_pool)
|
||||
assert '123456' in result.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_memory_stream():
|
||||
node_config = {
|
||||
"id": "llm_test",
|
||||
"type": "llm",
|
||||
"name": "LLM 问答",
|
||||
"config": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业、友好且乐于助人的 AI 助手。"
|
||||
"你的职责:- "
|
||||
"准确理解用户的问题并提供有价值的回答"
|
||||
"- 保持回答的专业性和准确性"
|
||||
"- 如果不确定答案,诚实地告知用户"
|
||||
"- 使用清晰、易懂的语言进行交流"
|
||||
"回答风格:"
|
||||
"- 简洁明了,直击要点"
|
||||
"- 必要时提供详细解释和示例"
|
||||
"- 使用友好、礼貌的语气"
|
||||
"- 适当使用格式化(如列表、段落)提高可读性"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{ sys.message }}"
|
||||
}
|
||||
],
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"memory": {
|
||||
"enable": True,
|
||||
"enable_window": True,
|
||||
"window_size": 5
|
||||
},
|
||||
"vision": False,
|
||||
"vision_input": "{{sys.files}}"
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("输出上一句话")
|
||||
async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool):
|
||||
if event.get("__final__"):
|
||||
assert '123456' in event.get("result").content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_vision():
|
||||
node_config = {
|
||||
"id": "llm_test",
|
||||
"type": "llm",
|
||||
"name": "LLM 问答",
|
||||
"config": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业、友好且乐于助人的 AI 助手。"
|
||||
"你的职责:- "
|
||||
"准确理解用户的问题并提供有价值的回答"
|
||||
"- 保持回答的专业性和准确性"
|
||||
"- 如果不确定答案,诚实地告知用户"
|
||||
"- 使用清晰、易懂的语言进行交流"
|
||||
"回答风格:"
|
||||
"- 简洁明了,直击要点"
|
||||
"- 必要时提供详细解释和示例"
|
||||
"- 使用友好、礼貌的语气"
|
||||
"- 适当使用格式化(如列表、段落)提高可读性"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{ sys.message }}"
|
||||
}
|
||||
],
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"memory": {
|
||||
"enable": True,
|
||||
"enable_window": True,
|
||||
"window_size": 5
|
||||
},
|
||||
"vision": True,
|
||||
"vision_input": "{{sys.files}}"
|
||||
}
|
||||
}
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("图片里面有什么")
|
||||
async for event in LLMNode(node_config, {}).execute_stream(state, variable_pool):
|
||||
if event.get("__final__"):
|
||||
assert '花' in event.get("result").content
|
||||
504
api/tests/workflow/nodes/test_parameter_extractor_node.py
Normal file
504
api/tests/workflow/nodes/test_parameter_extractor_node.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6 14:10
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import ParameterExtractorNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
# 基础参数提取配置 - 单个字符串参数
|
||||
SINGLE_STRING_PARAM_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "我的名字是张三,今年25岁",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "用户的姓名",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 多参数提取配置
|
||||
MULTI_PARAMS_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "我的名字是李四,今年30岁,住在北京",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "用户的姓名",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": "number",
|
||||
"desc": "用户的年龄",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "city",
|
||||
"type": "string",
|
||||
"desc": "用户所在的城市",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 数字参数提取配置
|
||||
NUMBER_PARAM_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "这个产品的价格是99.99元,库存有100件",
|
||||
"params": [
|
||||
{
|
||||
"name": "price",
|
||||
"type": "number",
|
||||
"desc": "产品价格",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "stock",
|
||||
"type": "number",
|
||||
"desc": "库存数量",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 布尔参数提取配置
|
||||
BOOLEAN_PARAM_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "这个用户已经完成了实名认证,但还没有绑定手机号",
|
||||
"params": [
|
||||
{
|
||||
"name": "verified",
|
||||
"type": "boolean",
|
||||
"desc": "是否完成实名认证",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "phone_bound",
|
||||
"type": "boolean",
|
||||
"desc": "是否绑定手机号",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 数组参数提取配置
|
||||
ARRAY_STRING_PARAM_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "我喜欢的水果有苹果、香蕉、橙子",
|
||||
"params": [
|
||||
{
|
||||
"name": "fruits",
|
||||
"type": "array[string]",
|
||||
"desc": "喜欢的水果列表",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 数字数组参数提取配置
|
||||
ARRAY_NUMBER_PARAM_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "这个月的销售额分别是:第一周10000,第二周12000,第三周15000,第四周18000",
|
||||
"params": [
|
||||
{
|
||||
"name": "weekly_sales",
|
||||
"type": "array[number]",
|
||||
"desc": "每周的销售额",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
# 带自定义提示的配置
|
||||
CUSTOM_PROMPT_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "订单号:ORD123456,金额:299元",
|
||||
"params": [
|
||||
{
|
||||
"name": "order_id",
|
||||
"type": "string",
|
||||
"desc": "订单编号",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "amount",
|
||||
"type": "number",
|
||||
"desc": "订单金额",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": "请仔细提取订单信息,确保订单号和金额准确无误"
|
||||
}
|
||||
}
|
||||
|
||||
# 使用变量的配置
|
||||
VARIABLE_INPUT_CONFIG = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "{{ conv.user_input }}",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "用户姓名",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": "number",
|
||||
"desc": "用户年龄",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 基础参数提取测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_single_string_param():
|
||||
"""测试提取单个字符串参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(SINGLE_STRING_PARAM_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert isinstance(result["name"], str)
|
||||
assert "张三" in result["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_multi_params():
|
||||
"""测试提取多个参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(MULTI_PARAMS_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "city" in result
|
||||
assert isinstance(result["name"], str)
|
||||
assert isinstance(result["age"], (int, float))
|
||||
assert "李四" in result["name"]
|
||||
assert result["age"] == 30
|
||||
assert "北京" in result["city"]
|
||||
|
||||
|
||||
# ==================== 数字参数提取测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_number_params():
|
||||
"""测试提取数字参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(NUMBER_PARAM_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "price" in result
|
||||
assert "stock" in result
|
||||
assert isinstance(result["price"], (int, float))
|
||||
assert isinstance(result["stock"], (int, float))
|
||||
assert abs(result["price"] - 99.99) < 0.1
|
||||
assert result["stock"] == 100
|
||||
|
||||
|
||||
# ==================== 布尔参数提取测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_boolean_params():
|
||||
"""测试提取布尔参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(BOOLEAN_PARAM_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "verified" in result
|
||||
assert "phone_bound" in result
|
||||
assert isinstance(result["verified"], bool)
|
||||
assert isinstance(result["phone_bound"], bool)
|
||||
assert result["verified"] is True
|
||||
assert result["phone_bound"] is False
|
||||
|
||||
|
||||
# ==================== 数组参数提取测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_array_string_param():
|
||||
"""测试提取字符串数组参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(ARRAY_STRING_PARAM_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "fruits" in result
|
||||
assert isinstance(result["fruits"], list)
|
||||
assert len(result["fruits"]) >= 3
|
||||
assert "苹果" in result["fruits"]
|
||||
assert "香蕉" in result["fruits"]
|
||||
assert "橙子" in result["fruits"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_array_number_param():
|
||||
"""测试提取数字数组参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(ARRAY_NUMBER_PARAM_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "weekly_sales" in result
|
||||
assert isinstance(result["weekly_sales"], list)
|
||||
assert len(result["weekly_sales"]) == 4
|
||||
assert 10000 in result["weekly_sales"]
|
||||
assert 12000 in result["weekly_sales"]
|
||||
assert 15000 in result["weekly_sales"]
|
||||
assert 18000 in result["weekly_sales"]
|
||||
|
||||
|
||||
# ==================== 自定义提示测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_custom_prompt():
|
||||
"""测试使用自定义提示提取参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await ParameterExtractorNode(CUSTOM_PROMPT_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "order_id" in result
|
||||
assert "amount" in result
|
||||
assert "ORD123456" in result["order_id"]
|
||||
assert isinstance(result["amount"], (int, float))
|
||||
assert result["amount"] == 299
|
||||
|
||||
|
||||
# ==================== 变量输入测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_variable_input():
|
||||
"""测试使用变量作为输入文本"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "user_input", "我叫王五,今年28岁", VariableType.STRING, mut=True)
|
||||
|
||||
result = await ParameterExtractorNode(VARIABLE_INPUT_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "王五" in result["name"]
|
||||
assert result["age"] == 28
|
||||
|
||||
|
||||
# ==================== 复杂场景测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_from_complex_text():
|
||||
"""测试从复杂文本中提取参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": """
|
||||
客户信息:
|
||||
姓名:赵六
|
||||
年龄:35岁
|
||||
职业:软件工程师
|
||||
城市:上海
|
||||
邮箱:zhaoliu@example.com
|
||||
是否VIP:是
|
||||
""",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "客户姓名",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": "number",
|
||||
"desc": "客户年龄",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "occupation",
|
||||
"type": "string",
|
||||
"desc": "客户职业",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "city",
|
||||
"type": "string",
|
||||
"desc": "所在城市",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "is_vip",
|
||||
"type": "boolean",
|
||||
"desc": "是否为VIP客户",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
result = await ParameterExtractorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "赵六" in result["name"]
|
||||
assert result["age"] == 35
|
||||
if "occupation" in result:
|
||||
assert "工程师" in result["occupation"]
|
||||
if "city" in result:
|
||||
assert "上海" in result["city"]
|
||||
if "is_vip" in result:
|
||||
assert result["is_vip"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_optional_params():
|
||||
"""测试提取可选参数"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "我叫小明",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "用户姓名",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": "number",
|
||||
"desc": "用户年龄",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "city",
|
||||
"type": "string",
|
||||
"desc": "所在城市",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
result = await ParameterExtractorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert "小明" in result["name"]
|
||||
# 可选参数可能不存在或为 None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_with_sys_message():
|
||||
"""测试使用系统消息变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("我叫小红,今年22岁")
|
||||
|
||||
config = {
|
||||
"id": "param_extractor_test",
|
||||
"type": "parameter-extractor",
|
||||
"name": "参数提取测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"text": "{{ sys.message }}",
|
||||
"params": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"desc": "用户姓名",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "age",
|
||||
"type": "number",
|
||||
"desc": "用户年龄",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
"prompt": ""
|
||||
}
|
||||
}
|
||||
|
||||
result = await ParameterExtractorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
assert "小红" in result["name"]
|
||||
assert result["age"] == 22
|
||||
647
api/tests/workflow/nodes/test_question_classifier_node.py
Normal file
647
api/tests/workflow/nodes/test_question_classifier_node.py
Normal file
@@ -0,0 +1,647 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import QuestionClassifierNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import TEST_MODEL_ID, simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
# 基础分类配置 - 两个类别
|
||||
BASIC_TWO_CATEGORIES_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我想买一台笔记本电脑",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
# 多类别配置
|
||||
MULTI_CATEGORIES_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我的订单什么时候能到?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "订单查询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
},
|
||||
{
|
||||
"class_name": "投诉建议"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
# 带补充提示的配置
|
||||
WITH_SUPPLEMENT_PROMPT_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "这个产品怎么样?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "用户评价"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": "如果用户在询问产品信息或特性,归类为产品咨询;如果是评价或反馈,归类为用户评价"
|
||||
}
|
||||
}
|
||||
|
||||
# 使用变量的配置
|
||||
VARIABLE_INPUT_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "{{ conv.user_question }}",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "技术支持"
|
||||
},
|
||||
{
|
||||
"class_name": "账号问题"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
# 使用系统消息的配置
|
||||
SYS_MESSAGE_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "{{ sys.message }}",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
# 空问题配置
|
||||
EMPTY_QUESTION_CONFIG = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 基础分类测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_product_inquiry():
|
||||
"""测试产品咨询分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await QuestionClassifierNode(BASIC_TWO_CATEGORIES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "class_name" in result
|
||||
assert "output" in result
|
||||
assert result["class_name"] == "产品咨询"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_after_sales():
|
||||
"""测试售后服务分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我的产品坏了,怎么维修?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "售后服务"
|
||||
assert result["output"] == "CASE2"
|
||||
|
||||
|
||||
# ==================== 多类别分类测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_order_inquiry():
|
||||
"""测试订单查询分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await QuestionClassifierNode(MULTI_CATEGORIES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "订单查询"
|
||||
assert result["output"] == "CASE2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_complaint():
|
||||
"""测试投诉建议分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "你们的服务态度太差了!",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "订单查询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
},
|
||||
{
|
||||
"class_name": "投诉建议"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "投诉建议"
|
||||
assert result["output"] == "CASE4"
|
||||
|
||||
|
||||
# ==================== 补充提示测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_with_supplement_prompt():
|
||||
"""测试使用补充提示进行分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await QuestionClassifierNode(WITH_SUPPLEMENT_PROMPT_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "class_name" in result
|
||||
assert "output" in result
|
||||
assert result["class_name"] in ["产品咨询", "用户评价"]
|
||||
assert result["output"] in ["CASE1", "CASE2"]
|
||||
|
||||
|
||||
# ==================== 变量输入测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_with_conv_variable():
|
||||
"""测试使用 conv 变量作为输入"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
await variable_pool.new("conv", "user_question", "我忘记密码了", VariableType.STRING, mut=True)
|
||||
|
||||
result = await QuestionClassifierNode(VARIABLE_INPUT_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "账号问题"
|
||||
assert result["output"] == "CASE2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_with_sys_message():
|
||||
"""测试使用系统消息变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("我想了解一下你们的产品功能")
|
||||
|
||||
result = await QuestionClassifierNode(SYS_MESSAGE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "产品咨询"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
# ==================== 边界情况测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_empty_question():
|
||||
"""测试空问题输入"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await QuestionClassifierNode(EMPTY_QUESTION_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "class_name" in result
|
||||
assert "output" in result
|
||||
# 空问题应该返回默认分类(第一个分类)
|
||||
assert result["class_name"] == "产品咨询"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_single_category():
|
||||
"""测试只有一个分类的情况"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "任何问题",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "通用咨询"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "通用咨询"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
# ==================== 复杂场景测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_ambiguous_question():
|
||||
"""测试模糊问题分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "你好",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
},
|
||||
{
|
||||
"class_name": "闲聊"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] in ["产品咨询", "售后服务", "闲聊"]
|
||||
assert result["output"] in ["CASE1", "CASE2", "CASE3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_long_question():
|
||||
"""测试长问题分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我在上个月购买了你们的产品,使用了一段时间后发现有一些问题,想咨询一下售后政策和维修流程,请问应该怎么办?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "售后服务"
|
||||
assert result["output"] == "CASE2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_technical_support():
|
||||
"""测试技术支持分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "软件安装失败,报错代码0x80070005",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "技术支持"
|
||||
},
|
||||
{
|
||||
"class_name": "账号问题"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "技术支持"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_multiple_categories():
|
||||
"""测试多个类别的详细分类"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我想申请退款",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "订单查询"
|
||||
},
|
||||
{
|
||||
"class_name": "退换货"
|
||||
},
|
||||
{
|
||||
"class_name": "售后服务"
|
||||
},
|
||||
{
|
||||
"class_name": "投诉建议"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "退换货"
|
||||
assert result["output"] == "CASE3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_with_detailed_supplement():
|
||||
"""测试使用详细补充提示"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "这个功能怎么用?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "产品使用"
|
||||
},
|
||||
{
|
||||
"class_name": "产品介绍"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": "如果用户询问如何使用某个功能,归类为产品使用;如果询问功能是什么或有什么功能,归类为产品介绍"
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "产品使用"
|
||||
assert result["output"] == "CASE1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_chinese_categories():
|
||||
"""测试中文类别名称"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "我要投诉",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "咨询类"
|
||||
},
|
||||
{
|
||||
"class_name": "投诉类"
|
||||
},
|
||||
{
|
||||
"class_name": "建议类"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] == "投诉类"
|
||||
assert result["output"] == "CASE2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_case_mapping():
|
||||
"""测试分类到 CASE 的映射关系"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "测试问题",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "类别A"
|
||||
},
|
||||
{
|
||||
"class_name": "类别B"
|
||||
},
|
||||
{
|
||||
"class_name": "类别C"
|
||||
},
|
||||
{
|
||||
"class_name": "类别D"
|
||||
},
|
||||
{
|
||||
"class_name": "类别E"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "class_name" in result
|
||||
assert "output" in result
|
||||
|
||||
# 验证 CASE 映射关系
|
||||
category_names = ["类别A", "类别B", "类别C", "类别D", "类别E"]
|
||||
if result["class_name"] in category_names:
|
||||
expected_case = f"CASE{category_names.index(result['class_name']) + 1}"
|
||||
assert result["output"] == expected_case
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_with_special_characters():
|
||||
"""测试包含特殊字符的问题"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "classifier_test",
|
||||
"type": "question-classifier",
|
||||
"name": "问题分类测试节点",
|
||||
"config": {
|
||||
"model_id": TEST_MODEL_ID,
|
||||
"input_variable": "产品价格是多少?有优惠吗?",
|
||||
"categories": [
|
||||
{
|
||||
"class_name": "价格咨询"
|
||||
},
|
||||
{
|
||||
"class_name": "促销活动"
|
||||
}
|
||||
],
|
||||
"system_prompt": "你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||
"user_prompt": "问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||
"user_supplement_prompt": None
|
||||
}
|
||||
}
|
||||
|
||||
result = await QuestionClassifierNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["class_name"] in ["价格咨询", "促销活动"]
|
||||
assert result["output"] in ["CASE1", "CASE2"]
|
||||
735
api/tests/workflow/nodes/test_start_node.py
Normal file
735
api/tests/workflow/nodes/test_start_node.py
Normal file
@@ -0,0 +1,735 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import StartNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from tests.workflow.nodes.base import (
|
||||
simple_state,
|
||||
simple_vairable_pool,
|
||||
TEST_EXECUTION_ID,
|
||||
TEST_WORKSPACE_ID,
|
||||
TEST_USER_ID,
|
||||
TEST_CONVERSATION_ID,
|
||||
TEST_FILE
|
||||
)
|
||||
|
||||
|
||||
async def create_variable_pool_with_inputs(message: str, input_variables: dict = None):
|
||||
"""创建带有自定义输入变量的变量池"""
|
||||
variable_pool = VariablePool()
|
||||
|
||||
sys_vars = {
|
||||
"message": (message, VariableType.STRING),
|
||||
"conversation_id": (TEST_CONVERSATION_ID, VariableType.STRING),
|
||||
"execution_id": (TEST_EXECUTION_ID, VariableType.STRING),
|
||||
"workspace_id": (TEST_WORKSPACE_ID, VariableType.STRING),
|
||||
"user_id": (TEST_USER_ID, VariableType.STRING),
|
||||
"input_variables": (input_variables or {}, VariableType.OBJECT),
|
||||
"files": ([TEST_FILE], VariableType.ARRAY_FILE)
|
||||
}
|
||||
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=VariableType(var_type),
|
||||
mut=False # 系统变量不可变
|
||||
)
|
||||
|
||||
return variable_pool
|
||||
|
||||
|
||||
# 基础配置 - 无自定义变量
|
||||
BASIC_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": []
|
||||
}
|
||||
}
|
||||
|
||||
# 带单个自定义变量的配置
|
||||
SINGLE_VARIABLE_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "language",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"default": "zh-CN",
|
||||
"description": "语言设置"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 带多个自定义变量的配置
|
||||
MULTI_VARIABLES_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "language",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"default": "zh-CN",
|
||||
"description": "语言设置"
|
||||
},
|
||||
{
|
||||
"name": "max_length",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 1000,
|
||||
"description": "最大长度"
|
||||
},
|
||||
{
|
||||
"name": "enable_cache",
|
||||
"type": "boolean",
|
||||
"required": False,
|
||||
"default": True,
|
||||
"description": "是否启用缓存"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 带必需变量的配置
|
||||
REQUIRED_VARIABLE_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "api_key",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "API密钥"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 混合必需和可选变量的配置
|
||||
MIXED_VARIABLES_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "user_id",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "用户ID"
|
||||
},
|
||||
{
|
||||
"name": "timeout",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 30,
|
||||
"description": "超时时间(秒)"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# 不同类型变量的配置
|
||||
DIFFERENT_TYPES_CONFIG = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"default": "default_name",
|
||||
"description": "名称"
|
||||
},
|
||||
{
|
||||
"name": "count",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 0,
|
||||
"description": "计数"
|
||||
},
|
||||
{
|
||||
"name": "enabled",
|
||||
"type": "boolean",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"description": "是否启用"
|
||||
},
|
||||
{
|
||||
"name": "tags",
|
||||
"type": "array[string]",
|
||||
"required": False,
|
||||
"default": [],
|
||||
"description": "标签列表"
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"type": "object",
|
||||
"required": False,
|
||||
"default": {},
|
||||
"description": "配置对象"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 基础功能测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_basic():
|
||||
"""测试基础 Start 节点(无自定义变量)"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test message")
|
||||
|
||||
result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "message" in result
|
||||
assert "execution_id" in result
|
||||
assert "conversation_id" in result
|
||||
assert "workspace_id" in result
|
||||
assert "user_id" in result
|
||||
assert result["message"] == "test message"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_system_variables():
|
||||
"""测试系统变量输出"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("hello world")
|
||||
|
||||
result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["message"] == "hello world"
|
||||
assert result["execution_id"] == state["execution_id"]
|
||||
assert result["workspace_id"] == state["workspace_id"]
|
||||
assert result["user_id"] == state["user_id"]
|
||||
|
||||
|
||||
# ==================== 自定义变量测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_single_variable_with_default():
|
||||
"""测试单个自定义变量使用默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert "language" in result
|
||||
assert result["language"] == "zh-CN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_single_variable_with_input():
|
||||
"""测试单个自定义变量使用输入值"""
|
||||
state = simple_state()
|
||||
|
||||
# 使用带输入变量的变量池
|
||||
input_vars = {"language": "en-US"}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert "language" in result
|
||||
assert result["language"] == "en-US"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_multi_variables_with_defaults():
|
||||
"""测试多个自定义变量使用默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert "language" in result
|
||||
assert "max_length" in result
|
||||
assert "enable_cache" in result
|
||||
assert result["language"] == "zh-CN"
|
||||
assert result["max_length"] == 1000
|
||||
assert result["enable_cache"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_multi_variables_with_inputs():
|
||||
"""测试多个自定义变量使用输入值"""
|
||||
state = simple_state()
|
||||
|
||||
# 使用带输入变量的变量池
|
||||
input_vars = {
|
||||
"language": "ja-JP",
|
||||
"max_length": 2000,
|
||||
"enable_cache": False
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["language"] == "ja-JP"
|
||||
assert result["max_length"] == 2000
|
||||
assert result["enable_cache"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_partial_inputs():
|
||||
"""测试部分输入变量,其他使用默认值"""
|
||||
state = simple_state()
|
||||
|
||||
# 只设置部分输入变量
|
||||
input_vars = {
|
||||
"language": "fr-FR"
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(MULTI_VARIABLES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["language"] == "fr-FR" # 使用输入值
|
||||
assert result["max_length"] == 1000 # 使用默认值
|
||||
assert result["enable_cache"] is True # 使用默认值
|
||||
|
||||
|
||||
# ==================== 必需变量测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_required_variable_provided():
|
||||
"""测试提供必需变量"""
|
||||
state = simple_state()
|
||||
|
||||
# 提供必需变量
|
||||
input_vars = {
|
||||
"api_key": "test_api_key_12345"
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert "api_key" in result
|
||||
assert result["api_key"] == "test_api_key_12345"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_required_variable_missing():
|
||||
"""测试缺少必需变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不提供必需变量
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await StartNode(REQUIRED_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert "缺少必需的输入变量" in str(exc_info.value)
|
||||
assert "api_key" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_mixed_variables():
|
||||
"""测试混合必需和可选变量"""
|
||||
state = simple_state()
|
||||
|
||||
# 只提供必需变量
|
||||
input_vars = {
|
||||
"user_id": "user_123"
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["user_id"] == "user_123" # 必需变量
|
||||
assert result["timeout"] == 30 # 可选变量使用默认值
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_mixed_variables_all_provided():
|
||||
"""测试混合变量全部提供"""
|
||||
state = simple_state()
|
||||
|
||||
# 提供所有变量
|
||||
input_vars = {
|
||||
"user_id": "user_456",
|
||||
"timeout": 60
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(MIXED_VARIABLES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["user_id"] == "user_456"
|
||||
assert result["timeout"] == 60
|
||||
|
||||
|
||||
# ==================== 不同类型变量测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_different_types_defaults():
|
||||
"""测试不同类型变量的默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
result = await StartNode(DIFFERENT_TYPES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["name"] == "default_name"
|
||||
assert result["count"] == 0
|
||||
assert result["enabled"] is False
|
||||
assert result["tags"] == []
|
||||
assert result["config"] == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_different_types_inputs():
|
||||
"""测试不同类型变量的输入值"""
|
||||
state = simple_state()
|
||||
|
||||
# 提供不同类型的输入值
|
||||
input_vars = {
|
||||
"name": "custom_name",
|
||||
"count": 100,
|
||||
"enabled": True,
|
||||
"tags": ["tag1", "tag2", "tag3"],
|
||||
"config": {"key": "value", "nested": {"data": 123}}
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(DIFFERENT_TYPES_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["name"] == "custom_name"
|
||||
assert result["count"] == 100
|
||||
assert result["enabled"] is True
|
||||
assert result["tags"] == ["tag1", "tag2", "tag3"]
|
||||
assert result["config"] == {"key": "value", "nested": {"data": 123}}
|
||||
|
||||
|
||||
# ==================== 边界情况测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_empty_message():
|
||||
"""测试空消息"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("")
|
||||
|
||||
result = await StartNode(BASIC_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["message"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_no_input_variables():
|
||||
"""测试没有输入变量的情况"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不设置 input_variables
|
||||
result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
# 应该使用默认值
|
||||
assert result["language"] == "zh-CN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_empty_input_variables():
|
||||
"""测试空的输入变量字典"""
|
||||
state = simple_state()
|
||||
|
||||
# 设置空的输入变量字典
|
||||
variable_pool = await create_variable_pool_with_inputs("test", {})
|
||||
|
||||
result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
# 应该使用默认值
|
||||
assert result["language"] == "zh-CN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_extra_input_variables():
|
||||
"""测试额外的输入变量(未在配置中定义)"""
|
||||
state = simple_state()
|
||||
|
||||
# 提供额外的未定义变量
|
||||
input_vars = {
|
||||
"language": "de-DE",
|
||||
"extra_var": "should_be_ignored"
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(SINGLE_VARIABLE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["language"] == "de-DE"
|
||||
assert "extra_var" not in result # 额外变量不应该出现在输出中
|
||||
|
||||
|
||||
# ==================== 数组类型变量测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_array_string_variable():
|
||||
"""测试字符串数组变量"""
|
||||
state = simple_state()
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "categories",
|
||||
"type": "array[string]",
|
||||
"required": False,
|
||||
"default": ["default1", "default2"],
|
||||
"description": "分类列表"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
input_vars = {
|
||||
"categories": ["cat1", "cat2", "cat3"]
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["categories"] == ["cat1", "cat2", "cat3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_array_number_variable():
|
||||
"""测试数字数组变量"""
|
||||
state = simple_state()
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "scores",
|
||||
"type": "array[number]",
|
||||
"required": False,
|
||||
"default": [0, 0, 0],
|
||||
"description": "分数列表"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
input_vars = {
|
||||
"scores": [85, 90, 95]
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["scores"] == [85, 90, 95]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_array_object_variable():
|
||||
"""测试对象数组变量"""
|
||||
state = simple_state()
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "users",
|
||||
"type": "array[object]",
|
||||
"required": False,
|
||||
"default": [],
|
||||
"description": "用户列表"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
input_vars = {
|
||||
"users": [
|
||||
{"name": "Alice", "age": 25},
|
||||
{"name": "Bob", "age": 30}
|
||||
]
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert len(result["users"]) == 2
|
||||
assert result["users"][0]["name"] == "Alice"
|
||||
assert result["users"][1]["age"] == 30
|
||||
|
||||
|
||||
# ==================== 复杂场景测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_complex_object():
|
||||
"""测试复杂对象变量"""
|
||||
state = simple_state()
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "settings",
|
||||
"type": "object",
|
||||
"required": False,
|
||||
"default": {"theme": "light"},
|
||||
"description": "设置对象"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
input_vars = {
|
||||
"settings": {
|
||||
"theme": "dark",
|
||||
"language": "zh-CN",
|
||||
"notifications": {
|
||||
"email": True,
|
||||
"sms": False
|
||||
},
|
||||
"features": ["feature1", "feature2"]
|
||||
}
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["settings"]["theme"] == "dark"
|
||||
assert result["settings"]["language"] == "zh-CN"
|
||||
assert result["settings"]["notifications"]["email"] is True
|
||||
assert result["settings"]["features"] == ["feature1", "feature2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_zero_and_false_values():
|
||||
"""测试零值和 False 值(确保不被当作空值)"""
|
||||
state = simple_state()
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "count",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 10,
|
||||
"description": "计数"
|
||||
},
|
||||
{
|
||||
"name": "enabled",
|
||||
"type": "boolean",
|
||||
"required": False,
|
||||
"default": True,
|
||||
"description": "是否启用"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
input_vars = {
|
||||
"count": 0,
|
||||
"enabled": False
|
||||
}
|
||||
variable_pool = await create_variable_pool_with_inputs("test", input_vars)
|
||||
|
||||
result = await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# 0 和 False 应该被正确识别,而不是使用默认值
|
||||
assert result["count"] == 0
|
||||
assert result["enabled"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_output_types():
|
||||
"""测试输出类型定义"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
node = StartNode(MULTI_VARIABLES_CONFIG, {})
|
||||
await node.execute(state, variable_pool)
|
||||
|
||||
output_types = node._output_types()
|
||||
|
||||
# 验证系统变量类型
|
||||
assert output_types["message"] == VariableType.STRING
|
||||
assert output_types["execution_id"] == VariableType.STRING
|
||||
assert output_types["conversation_id"] == VariableType.STRING
|
||||
assert output_types["workspace_id"] == VariableType.STRING
|
||||
assert output_types["user_id"] == VariableType.STRING
|
||||
|
||||
# 验证自定义变量类型
|
||||
assert output_types["language"] == VariableType.STRING
|
||||
assert output_types["max_length"] == VariableType.NUMBER
|
||||
assert output_types["enable_cache"] == VariableType.BOOLEAN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_multiple_executions():
|
||||
"""测试多次执行 Start 节点"""
|
||||
state = simple_state()
|
||||
|
||||
node = StartNode(SINGLE_VARIABLE_CONFIG, {})
|
||||
|
||||
# 第一次执行
|
||||
variable_pool1 = await create_variable_pool_with_inputs("first message", {})
|
||||
result1 = await node.execute(state, variable_pool1)
|
||||
assert result1["message"] == "first message"
|
||||
assert result1["language"] == "zh-CN"
|
||||
|
||||
# 第二次执行(使用新的变量池)
|
||||
variable_pool2 = await create_variable_pool_with_inputs("second message", {})
|
||||
result2 = await node.execute(state, variable_pool2)
|
||||
assert result2["message"] == "second message"
|
||||
assert result2["language"] == "zh-CN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_node_with_description():
|
||||
"""测试带描述的变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "start_test",
|
||||
"type": "start",
|
||||
"name": "开始节点",
|
||||
"config": {
|
||||
"variables": [
|
||||
{
|
||||
"name": "api_endpoint",
|
||||
"type": "string",
|
||||
"required": True,
|
||||
"description": "API 端点 URL,用于连接外部服务"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 测试缺少必需变量时,错误信息包含描述
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await StartNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert "api_endpoint" in str(exc_info.value)
|
||||
assert "API 端点 URL" in str(exc_info.value)
|
||||
621
api/tests/workflow/nodes/test_variable_aggregator_node.py
Normal file
621
api/tests/workflow/nodes/test_variable_aggregator_node.py
Normal file
@@ -0,0 +1,621 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.nodes import VariableAggregatorNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from tests.workflow.nodes.base import simple_state, simple_vairable_pool
|
||||
|
||||
|
||||
# 非分组模式配置 - 返回第一个非空变量
|
||||
NON_GROUP_CONFIG = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.var1}}",
|
||||
"{{conv.var2}}",
|
||||
"{{conv.var3}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 非分组模式配置 - 带类型定义
|
||||
NON_GROUP_WITH_TYPE_CONFIG = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.var1}}",
|
||||
"{{conv.var2}}"
|
||||
],
|
||||
"group_type": {
|
||||
"output": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 分组模式配置
|
||||
GROUP_CONFIG = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": True,
|
||||
"group_variables": {
|
||||
"user_message": [
|
||||
"{{conv.msg1}}",
|
||||
"{{conv.msg2}}"
|
||||
],
|
||||
"user_name": [
|
||||
"{{conv.name1}}",
|
||||
"{{conv.name2}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 分组模式配置 - 带类型定义
|
||||
GROUP_WITH_TYPE_CONFIG = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": True,
|
||||
"group_variables": {
|
||||
"count": [
|
||||
"{{conv.count1}}",
|
||||
"{{conv.count2}}"
|
||||
],
|
||||
"enabled": [
|
||||
"{{conv.flag1}}",
|
||||
"{{conv.flag2}}"
|
||||
]
|
||||
},
|
||||
"group_type": {
|
||||
"count": "number",
|
||||
"enabled": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 非分组模式测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_first_variable():
|
||||
"""测试非分组模式返回第一个非空变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 设置变量
|
||||
await variable_pool.new("conv", "var1", "first_value", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "var2", "second_value", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "var3", "third_value", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == "first_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_skip_none():
|
||||
"""测试非分组模式跳过 None 值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 第一个变量不存在,第二个存在
|
||||
await variable_pool.new("conv", "var2", "second_value", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "var3", "third_value", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == "second_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_all_none():
|
||||
"""测试非分组模式所有变量都不存在"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不创建任何变量
|
||||
result = await VariableAggregatorNode(NON_GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_with_type_all_none():
|
||||
"""测试非分组模式带类型定义,所有变量都不存在"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不创建任何变量
|
||||
result = await VariableAggregatorNode(NON_GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
# 应该返回类型的默认值
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_different_types():
|
||||
"""测试非分组模式不同类型的变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.num}}",
|
||||
"{{conv.str}}",
|
||||
"{{conv.bool}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置不同类型的变量
|
||||
await variable_pool.new("conv", "num", 123, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "str", "text", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "bool", True, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_zero_and_false():
|
||||
"""测试非分组模式零值和 False 值(不应被视为 None)"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.zero}}",
|
||||
"{{conv.text}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置零值
|
||||
await variable_pool.new("conv", "zero", 0, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# 0 不应被视为 None,应该返回 0
|
||||
assert result == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_false_value():
|
||||
"""测试非分组模式 False 值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.flag}}",
|
||||
"{{conv.text}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置 False 值
|
||||
await variable_pool.new("conv", "flag", False, VariableType.BOOLEAN, mut=True)
|
||||
await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# False 不应被视为 None,应该返回 False
|
||||
assert result is False
|
||||
|
||||
|
||||
# ==================== 分组模式测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_all_groups():
|
||||
"""测试分组模式所有分组都有值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 设置变量
|
||||
await variable_pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "name1", "Alice", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["user_message"] == "Hello"
|
||||
assert result["user_name"] == "Alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_fallback():
|
||||
"""测试分组模式使用备用变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 第一个变量不存在,使用第二个
|
||||
await variable_pool.new("conv", "msg2", "Fallback message", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "name2", "Bob", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["user_message"] == "Fallback message"
|
||||
assert result["user_name"] == "Bob"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_partial_none():
|
||||
"""测试分组模式部分分组没有值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 只设置一个分组的变量
|
||||
await variable_pool.new("conv", "msg1", "Hello", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["user_message"] == "Hello"
|
||||
assert result["user_name"] == "" # 没有值的分组返回空字符串
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_all_none():
|
||||
"""测试分组模式所有分组都没有值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不创建任何变量
|
||||
result = await VariableAggregatorNode(GROUP_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["user_message"] == ""
|
||||
assert result["user_name"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_with_type():
|
||||
"""测试分组模式带类型定义"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 设置变量
|
||||
await variable_pool.new("conv", "count1", 100, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "flag1", True, VariableType.BOOLEAN, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["count"] == 100
|
||||
assert result["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_with_type_defaults():
|
||||
"""测试分组模式带类型定义,使用默认值"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 不创建任何变量
|
||||
result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
# 应该返回类型的默认值
|
||||
assert result["count"] == 0 # number 的默认值
|
||||
assert result["enabled"] is False # boolean 的默认值
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_mixed_values():
|
||||
"""测试分组模式混合有值和无值的情况"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
# 只设置 count2
|
||||
await variable_pool.new("conv", "count2", 200, VariableType.NUMBER, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["count"] == 200 # 使用第二个变量
|
||||
assert result["enabled"] is False # 没有值,使用默认值
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_multiple_groups():
|
||||
"""测试分组模式多个分组"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": True,
|
||||
"group_variables": {
|
||||
"group1": ["{{conv.g1_v1}}", "{{conv.g1_v2}}"],
|
||||
"group2": ["{{conv.g2_v1}}", "{{conv.g2_v2}}"],
|
||||
"group3": ["{{conv.g3_v1}}", "{{conv.g3_v2}}"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 设置不同分组的变量
|
||||
await variable_pool.new("conv", "g1_v1", "value1", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "g2_v2", "value2", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "g3_v1", "value3", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["group1"] == "value1"
|
||||
assert result["group2"] == "value2"
|
||||
assert result["group3"] == "value3"
|
||||
|
||||
|
||||
# ==================== 复杂场景测试 ====================
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_with_array():
|
||||
"""测试聚合数组变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.arr1}}",
|
||||
"{{conv.arr2}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置数组变量
|
||||
await variable_pool.new("conv", "arr1", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "arr2", [4, 5, 6], VariableType.ARRAY_NUMBER, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_with_object():
|
||||
"""测试聚合对象变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.obj1}}",
|
||||
"{{conv.obj2}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置对象变量
|
||||
await variable_pool.new("conv", "obj1", {"key": "value1"}, VariableType.OBJECT, mut=True)
|
||||
await variable_pool.new("conv", "obj2", {"key": "value2"}, VariableType.OBJECT, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == {"key": "value1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_empty_string():
|
||||
"""测试空字符串不被视为 None"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.empty}}",
|
||||
"{{conv.text}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置空字符串
|
||||
await variable_pool.new("conv", "empty", "", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "text", "fallback", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# 空字符串不应被视为 None,应该返回空字符串
|
||||
assert result == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_empty_array():
|
||||
"""测试空数组不被视为 None"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.empty_arr}}",
|
||||
"{{conv.arr}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置空数组
|
||||
await variable_pool.new("conv", "empty_arr", [], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "arr", [1, 2], VariableType.ARRAY_NUMBER, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# 空数组不应被视为 None,应该返回空数组
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_empty_object():
|
||||
"""测试空对象不被视为 None"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.empty_obj}}",
|
||||
"{{conv.obj}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# 设置空对象
|
||||
await variable_pool.new("conv", "empty_obj", {}, VariableType.OBJECT, mut=True)
|
||||
await variable_pool.new("conv", "obj", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
# 空对象不应被视为 None,应该返回空对象
|
||||
assert result == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_with_different_types():
|
||||
"""测试分组模式不同类型的变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": True,
|
||||
"group_variables": {
|
||||
"text": ["{{conv.str1}}", "{{conv.str2}}"],
|
||||
"number": ["{{conv.num1}}", "{{conv.num2}}"],
|
||||
"array": ["{{conv.arr1}}", "{{conv.arr2}}"],
|
||||
"object": ["{{conv.obj1}}", "{{conv.obj2}}"]
|
||||
},
|
||||
"group_type": {
|
||||
"text": "string",
|
||||
"number": "number",
|
||||
"array": "array[number]",
|
||||
"object": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 设置不同类型的变量
|
||||
await variable_pool.new("conv", "str1", "hello", VariableType.STRING, mut=True)
|
||||
await variable_pool.new("conv", "num1", 42, VariableType.NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "arr1", [1, 2, 3], VariableType.ARRAY_NUMBER, mut=True)
|
||||
await variable_pool.new("conv", "obj1", {"key": "value"}, VariableType.OBJECT, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result["text"] == "hello"
|
||||
assert result["number"] == 42
|
||||
assert result["array"] == [1, 2, 3]
|
||||
assert result["object"] == {"key": "value"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregator_output_types():
|
||||
"""测试输出类型定义"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
node = VariableAggregatorNode(GROUP_WITH_TYPE_CONFIG, {})
|
||||
|
||||
output_types = node._output_types()
|
||||
|
||||
assert output_types["count"] == VariableType.NUMBER
|
||||
assert output_types["enabled"] == VariableType.BOOLEAN
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_group_single_variable():
|
||||
"""测试非分组模式只有一个变量"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": False,
|
||||
"group_variables": [
|
||||
"{{conv.only_var}}"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
await variable_pool.new("conv", "only_var", "single_value", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert result == "single_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mode_single_group():
|
||||
"""测试分组模式只有一个分组"""
|
||||
state = simple_state()
|
||||
variable_pool = await simple_vairable_pool("test")
|
||||
|
||||
config = {
|
||||
"id": "aggregator_test",
|
||||
"type": "var-aggregator",
|
||||
"name": "变量聚合测试节点",
|
||||
"config": {
|
||||
"group": True,
|
||||
"group_variables": {
|
||||
"only_group": ["{{conv.var1}}", "{{conv.var2}}"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await variable_pool.new("conv", "var1", "value", VariableType.STRING, mut=True)
|
||||
|
||||
result = await VariableAggregatorNode(config, {}).execute(state, variable_pool)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["only_group"] == "value"
|
||||
Submodule redbear-mem-benchmark updated: 4b0257bb4e...0c4bcafbc1
@@ -36,7 +36,7 @@ async def run_code(request: RunCodeRequest):
|
||||
elif request.language == "javascript":
|
||||
return await run_nodejs_code(request.code, request.preload, request.options)
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
return error_response(400, "unsupported language")
|
||||
|
||||
|
||||
@router.get("/dependencies", response_model=ApiResponse)
|
||||
@@ -45,7 +45,7 @@ async def get_dependencies(language: str):
|
||||
if language == "python3":
|
||||
return await list_python_dependencies()
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
return error_response(400, "unsupported language")
|
||||
|
||||
|
||||
@router.post("/dependencies/update", response_model=ApiResponse)
|
||||
@@ -54,4 +54,4 @@ async def update_dependencies(request: UpdateDependencyRequest):
|
||||
if request.language == "python3":
|
||||
return await update_python_dependencies()
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
return error_response(400, "unsupported language")
|
||||
|
||||
@@ -75,6 +75,4 @@ def success_response(data: Any) -> ApiResponse:
|
||||
|
||||
def error_response(code: int, message: str) -> ApiResponse:
|
||||
"""Create error response"""
|
||||
if code >= 0:
|
||||
code = -1
|
||||
return ApiResponse(code=code, message=message, data=None)
|
||||
|
||||
@@ -27,11 +27,11 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions):
|
||||
try:
|
||||
runner = NodejsRunner()
|
||||
result = await runner.run(code, options, preload)
|
||||
if result.exit_code == signal.SIGSYS + 0x80:
|
||||
if result.exit_code in [signal.SIGSYS + 0x80, -signal.SIGSYS]:
|
||||
return error_response(31, "sandbox security policy violation")
|
||||
|
||||
if result.exit_code != 0:
|
||||
return error_response(500, result.stderr)
|
||||
return error_response(result.exit_code, result.stderr)
|
||||
|
||||
return success_response(RunCodeResponse(
|
||||
stdout=result.stdout,
|
||||
@@ -39,5 +39,5 @@ async def run_nodejs_code(code: str, preload: str, options: RunnerOptions):
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}", exc_info=True)
|
||||
return error_response(-500, str(e))
|
||||
logger.error(f"JavaScript execution failed: {e}", exc_info=True)
|
||||
return error_response(500, str(e))
|
||||
|
||||
@@ -47,7 +47,7 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}", exc_info=True)
|
||||
return error_response(-500, str(e))
|
||||
return error_response(500, str(e))
|
||||
|
||||
|
||||
async def list_python_dependencies():
|
||||
|
||||
Reference in New Issue
Block a user