Merge branch 'develop' into feature/workflow_import_zy

This commit is contained in:
yingzhao
2026-03-03 10:16:59 +08:00
committed by GitHub
77 changed files with 3159 additions and 847 deletions

View File

@@ -82,7 +82,7 @@ celery_app.conf.update(
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
},
)
@@ -115,16 +115,11 @@ beat_schedule_config = {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
"write-all-workspaces-memory": {
"task": "app.tasks.write_all_workspaces_memory_task",
"schedule": memory_increment_schedule,
"args": (),
},
}
#如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -0,0 +1 @@
"""Configuration module for application settings."""

View File

@@ -0,0 +1,239 @@
"""默认本体场景配置
本模块定义系统预设的本体场景和实体类型配置。
这些配置用于在工作空间创建时自动初始化默认场景。
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
"""
# Online-education scene: one preset ontology scene with bilingual (zh/en)
# names and descriptions for the scene itself and each of its entity types.
# The language-specific value is selected at workspace-creation time by the
# get_scene_*/get_type_* helpers below.
ONLINE_EDUCATION_SCENE = {
    "name_chinese": "在线教育",
    "name_english": "Online Education",
    "description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
    "description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
    # Entity types seeded into the scene; each entry mirrors the bilingual
    # name/description key layout of the scene itself.
    "types": [
        {
            "name_chinese": "学生",
            "name_english": "Student",
            "description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
            "description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
        },
        {
            "name_chinese": "教师",
            "name_english": "Teacher",
            "description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
            "description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
        },
        {
            "name_chinese": "课程",
            "name_english": "Course",
            "description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
            "description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
        },
        {
            "name_chinese": "作业",
            "name_english": "Assignment",
            "description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
            "description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
        },
        {
            "name_chinese": "成绩",
            "name_english": "Grade",
            "description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
            "description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
        },
        {
            "name_chinese": "考试",
            "name_english": "Exam",
            "description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
            "description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
        },
        {
            "name_chinese": "教室",
            "name_english": "Classroom",
            "description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
            "description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
        },
        {
            "name_chinese": "学科",
            "name_english": "Subject",
            "description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
            "description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
        },
        {
            "name_chinese": "教材",
            "name_english": "Textbook",
            # NOTE(review): the Chinese description below is missing a "," after
            # "资料" — runtime data, left unchanged here; fix in a code change.
            "description_chinese": "教学使用的书籍或资料包含书名、作者、出版社、ISBN等属性",
            "description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
        },
        {
            "name_chinese": "班级",
            "name_english": "Class",
            "description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
            "description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
        },
        {
            "name_chinese": "学期",
            "name_english": "Semester",
            "description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
            "description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
        },
        {
            "name_chinese": "课时",
            "name_english": "Class Hour",
            "description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
            "description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
        },
        {
            "name_chinese": "教学计划",
            "name_english": "Teaching Plan",
            "description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
            "description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
        }
    ]
}
# Emotional-companion scene: the second preset ontology scene, with the same
# bilingual key layout as ONLINE_EDUCATION_SCENE (name_chinese/name_english,
# description_chinese/description_english, plus a "types" list of entities).
EMOTIONAL_COMPANION_SCENE = {
    "name_chinese": "情感陪伴",
    "name_english": "Emotional Companion",
    "description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
    "description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
    # Entity types seeded into the scene.
    "types": [
        {
            "name_chinese": "用户",
            "name_english": "User",
            "description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
            "description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
        },
        {
            "name_chinese": "情绪",
            "name_english": "Emotion",
            "description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
            "description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
        },
        {
            "name_chinese": "活动",
            "name_english": "Activity",
            "description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
            "description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
        },
        {
            "name_chinese": "对话",
            "name_english": "Conversation",
            "description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
            "description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
        },
        {
            "name_chinese": "兴趣爱好",
            "name_english": "Hobby",
            "description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
            "description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
        },
        {
            "name_chinese": "日常事件",
            "name_english": "Daily Event",
            "description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
            "description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
        },
        {
            "name_chinese": "关系",
            "name_english": "Relationship",
            "description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
            "description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
        },
        {
            "name_chinese": "回忆",
            "name_english": "Memory",
            "description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
            "description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
        },
        {
            "name_chinese": "地点",
            "name_english": "Location",
            "description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
            "description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
        },
        {
            "name_chinese": "时间节点",
            "name_english": "Time Point",
            "description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
            "description_english": "Important time markers, including attributes such as date, event, and significance"
        },
        {
            "name_chinese": "目标",
            "name_english": "Goal",
            "description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
            "description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
        },
        {
            "name_chinese": "成就",
            "name_english": "Achievement",
            "description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
            "description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
        }
    ]
}
# Default scenes provisioned for every newly created workspace.
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
    """Return the scene name in the requested language.

    Args:
        scene_config: Scene configuration dict (see DEFAULT_SCENES entries).
        language: "zh" or "en"; any value other than "en" selects Chinese.

    Returns:
        The localized scene name; English falls back to the Chinese name
        when "name_english" is absent.
    """
    chinese_name = scene_config.get("name_chinese")
    if language != "en":
        return chinese_name
    return scene_config.get("name_english", chinese_name)
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
    """Return the scene description in the requested language.

    Args:
        scene_config: Scene configuration dict (see DEFAULT_SCENES entries).
        language: "zh" or "en"; any value other than "en" selects Chinese.

    Returns:
        The localized scene description; English falls back to the Chinese
        description when "description_english" is absent.
    """
    chinese_desc = scene_config.get("description_chinese")
    if language == "en":
        return scene_config.get("description_english", chinese_desc)
    return chinese_desc
def get_type_name(type_config: dict, language: str = "zh") -> str:
    """Return the entity-type name in the requested language.

    Args:
        type_config: One entry of a scene's "types" list.
        language: "zh" or "en"; any value other than "en" selects Chinese.

    Returns:
        The localized type name; English falls back to the Chinese name
        when "name_english" is absent.
    """
    default = type_config.get("name_chinese")
    return type_config.get("name_english", default) if language == "en" else default
def get_type_description(type_config: dict, language: str = "zh") -> str:
    """Return the entity-type description in the requested language.

    Args:
        type_config: One entry of a scene's "types" list.
        language: "zh" or "en"; any value other than "en" selects Chinese.

    Returns:
        The localized type description; English falls back to the Chinese
        description when "description_english" is absent.
    """
    default = type_config.get("description_chinese")
    return type_config.get("description_english", default) if language == "en" else default

View File

@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-
"""默认本体场景初始化器
本模块提供默认本体场景和类型的自动初始化功能。
在工作空间创建时,自动添加预设的本体场景和实体类型。
Classes:
DefaultOntologyInitializer: 默认本体场景初始化器
"""
import logging
from typing import List, Optional, Tuple
from uuid import UUID
from sqlalchemy.orm import Session
from app.config.default_ontology_config import (
DEFAULT_SCENES,
get_scene_name,
get_scene_description,
get_type_name,
get_type_description,
)
from app.core.logging_config import get_business_logger
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.repositories.ontology_class_repository import OntologyClassRepository
class DefaultOntologyInitializer:
    """Seeds the default ontology scenes and entity types for a workspace.

    Creates the preset scenes (online education, emotional companion) and
    their entity types when a workspace is created.  Designed to be
    minimally invasive: failures are logged and reported through the
    return value instead of raising, so workspace creation is never
    blocked by ontology initialization.

    Fix over the previous version: a scene that already exists (same name
    in the workspace) is now counted as a skip/success instead of a
    failure, so re-running the initializer is idempotent — previously a
    re-run where all scenes existed returned (False, "所有默认场景创建失败").
    The created-types total now also reflects what was actually persisted
    rather than the configured type count.

    Attributes:
        db: SQLAlchemy session used for all writes.
        scene_repo: Repository for ontology scenes.
        class_repo: Repository for ontology classes (entity types).
        logger: Business logger.
    """

    def __init__(self, db: Session):
        """Store the session and build the repositories and logger.

        Args:
            db: Active SQLAlchemy database session.
        """
        self.db = db
        self.scene_repo = OntologySceneRepository(db)
        self.class_repo = OntologyClassRepository(db)
        self.logger = get_business_logger()

    def initialize_default_scenes(
        self,
        workspace_id: UUID,
        language: str = "zh"
    ) -> Tuple[bool, str]:
        """Initialize the default scenes (and their types) for a workspace.

        Iterates over DEFAULT_SCENES, creating each scene and its entity
        types.  Scenes that already exist are skipped and treated as
        success (idempotent re-run).  Errors are logged, never raised.

        Args:
            workspace_id: Target workspace ID.
            language: "zh" or "en"; selects which localized names are used.

        Returns:
            Tuple[bool, str]: (success, error message — empty on success).
            Success means at least one scene was created or already present.
        """
        try:
            self.logger.info(
                f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
            )
            scenes_created = 0
            scenes_skipped = 0
            total_types_created = 0
            # Walk the preset scene configurations.
            for scene_config in DEFAULT_SCENES:
                scene_name = get_scene_name(scene_config, language)
                scene_id, types_created, already_existed = self._create_scene_with_types(
                    workspace_id, scene_config, language
                )
                if already_existed:
                    # Idempotent re-run: an existing scene is not a failure.
                    scenes_skipped += 1
                elif scene_id:
                    scenes_created += 1
                    # Count what was actually persisted, not the config size.
                    total_types_created += types_created
                    self.logger.info(
                        f"场景创建成功 - scene_name={scene_name}, "
                        f"scene_id={scene_id}, types_count={types_created}, language={language}"
                    )
                else:
                    self.logger.warning(
                        f"场景创建失败 - scene_name={scene_name}, "
                        f"workspace_id={workspace_id}, language={language}"
                    )
            # Log the overall outcome.
            self.logger.info(
                f"默认场景初始化完成 - workspace_id={workspace_id}, "
                f"language={language}, scenes_created={scenes_created}, "
                f"scenes_skipped={scenes_skipped}, "
                f"total_types_created={total_types_created}"
            )
            # Success when at least one scene was created or already present.
            if scenes_created + scenes_skipped > 0:
                return True, ""
            error_msg = "所有默认场景创建失败"
            self.logger.error(
                f"默认场景初始化失败 - workspace_id={workspace_id}, "
                f"language={language}, error={error_msg}"
            )
            return False, error_msg
        except Exception as e:
            error_msg = f"默认场景初始化异常: {str(e)}"
            self.logger.error(
                f"默认场景初始化异常 - workspace_id={workspace_id}, "
                f"language={language}, error={str(e)}",
                exc_info=True
            )
            return False, error_msg

    def _create_scene_with_types(
        self,
        workspace_id: UUID,
        scene_config: dict,
        language: str = "zh"
    ) -> Tuple[Optional[UUID], int, bool]:
        """Create one scene plus its entity types.

        Args:
            workspace_id: Target workspace ID.
            scene_config: Scene configuration dict.
            language: "zh" or "en".

        Returns:
            Tuple of (scene_id, types_created, already_existed):
            scene_id is None when creation failed or was skipped;
            already_existed is True when a same-name scene was found and
            creation was skipped (backward compatibility).
        """
        # Resolve the name outside the try block so the except handler
        # does not have to recompute it (and cannot fail doing so).
        scene_name = get_scene_name(scene_config, language)
        try:
            scene_description = get_scene_description(scene_config, language)
            # Backward compatibility: skip when a same-name scene exists.
            existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
            if existing_scene:
                self.logger.info(
                    f"场景已存在,跳过创建 - scene_name={scene_name}, "
                    f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
                    f"language={language}"
                )
                return None, 0, True
            scene_data = {
                "scene_name": scene_name,
                "scene_description": scene_description
            }
            scene = self.scene_repo.create(scene_data, workspace_id)
            # Mark as system default so the API layer can reject edit/delete.
            scene.is_system_default = True
            self.db.flush()
            self.logger.info(
                f"场景创建成功 - scene_name={scene_name}, "
                f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
            )
            # Create the scene's entity types in a batch.
            types_config = scene_config.get("types", [])
            types_created = self._batch_create_types(scene.scene_id, types_config, language)
            self.logger.info(
                f"场景类型创建完成 - scene_id={scene.scene_id}, "
                f"types_created={types_created}/{len(types_config)}, language={language}"
            )
            return scene.scene_id, types_created, False
        except Exception as e:
            self.logger.error(
                f"场景创建失败 - scene_name={scene_name}, "
                f"workspace_id={workspace_id}, language={language}, error={str(e)}",
                exc_info=True
            )
            return None, 0, False

    def _batch_create_types(
        self,
        scene_id: UUID,
        types_config: List[dict],
        language: str = "zh"
    ) -> int:
        """Create the entity types for one scene, one at a time.

        A single failing type is logged and skipped so the remaining
        types are still created.

        Args:
            scene_id: Parent scene ID.
            types_config: List of type configuration dicts.
            language: "zh" or "en".

        Returns:
            int: number of types successfully created.
        """
        created_count = 0
        for type_config in types_config:
            # Resolve the name before the try block — see note above.
            type_name = get_type_name(type_config, language)
            try:
                type_description = get_type_description(type_config, language)
                class_data = {
                    "class_name": type_name,
                    "class_description": type_description
                }
                ontology_class = self.class_repo.create(class_data, scene_id)
                # Mark as system default so the API layer can reject edit/delete.
                ontology_class.is_system_default = True
                self.db.flush()
                created_count += 1
                self.logger.debug(
                    f"类型创建成功 - class_name={type_name}, "
                    f"class_id={ontology_class.class_id}, "
                    f"scene_id={scene_id}, is_system_default=True, language={language}"
                )
            except Exception as e:
                self.logger.warning(
                    f"单个类型创建失败,继续创建其他类型 - "
                    f"class_name={type_name}, scene_id={scene_id}, "
                    f"language={language}, error={str(e)}"
                )
                continue
        return created_count

View File

@@ -633,12 +633,11 @@ async def get_knowledge_type_stats_api(
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
- 如果用户没有当前工作空间,对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:

View File

@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -469,6 +470,8 @@ async def get_chunk_insight(
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -503,6 +506,15 @@ async def dashboard_data(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 如果没有提供时间范围默认使用最近30天
if start_date is None or end_date is None:
from datetime import datetime, timedelta
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=30)
end_date = int(end_dt.timestamp() * 1000)
start_date = int(start_dt.timestamp() * 1000)
api_logger.info(f"使用默认时间范围: {start_dt}{end_dt}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -563,17 +575,22 @@ async def dashboard_data(
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用增量total_api_call,转换为整数
# 3. 获取API调用统计total_api_call
try:
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
# 使用 AppStatisticsService 获取真实的API调用统计
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
current_user=current_user
start_date=start_date,
end_date=end_date
)
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取API调用增量失败: {str(e)}")
api_logger.error(f"获取API调用统计失败: {str(e)}")
neo4j_data["total_api_call"] = 0
result["neo4j_data"] = neo4j_data
api_logger.info("成功获取neo4j_data")
@@ -602,10 +619,23 @@ async def dashboard_data(
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -31,7 +31,7 @@ from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
logger = logging.getLogger(__name__)
router = APIRouter(
@@ -399,6 +400,20 @@ async def update_scene(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认场景
scene_repo = OntologySceneRepository(db)
scene = scene_repo.get_by_id(scene_uuid)
if scene and scene.is_system_default:
business_logger.warning(
f"尝试修改系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可修改",
"该场景为系统预设场景,不允许修改"
)
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
@@ -491,6 +506,20 @@ async def delete_scene(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认场景
scene_repo = OntologySceneRepository(db)
scene = scene_repo.get_by_id(scene_uuid)
if scene and scene.is_system_default:
business_logger.warning(
f"尝试删除系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可删除",
"该场景为系统预设场景,不允许删除"
)
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig

View File

@@ -11,7 +11,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.repositories.ontology_class_repository import OntologyClassRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
@@ -366,6 +368,20 @@ async def update_class_handler(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试修改系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可修改",
"该类型为系统预设类型,不允许修改"
)
# 创建Service
service = _get_dummy_ontology_service(db)
@@ -429,6 +445,20 @@ async def delete_class_handler(
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试删除系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可删除",
"该类型为系统预设类型,不允许删除"
)
# 创建Service
service = _get_dummy_ontology_service(db)

View File

@@ -89,7 +89,6 @@ async def chat(
body = await request.json()
payload = AppChatRequest(**body)
other_id = payload.user_id
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id
workspace_id = app.workspace_id
@@ -135,7 +134,8 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
user_id=end_user_id,
is_draft=False
is_draft=False,
conversation_id=payload.conversation_id
)
if app_type == AppType.AGENT:

View File

@@ -100,7 +100,7 @@ def get_current_user_info(
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
if ws.workspace_id == current_user.current_workspace_id:
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
result_schema.role = ws.role
break

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
@@ -95,16 +95,29 @@ def get_workspaces(
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""创建新的工作空间"""
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
from app.core.language_utils import get_language_from_header
# 验证并获取语言参数
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
f"language={language}"
)
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user)
db=db, workspace=workspace, user=current_user, language=language
)
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
api_logger.info(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}"
)
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")

View File

@@ -201,7 +201,6 @@ class Settings:
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration

View File

@@ -17,6 +17,8 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
@@ -32,6 +34,41 @@ class SummaryNodeService(LLMServiceMixin):
# 创建全局服务实例
summary_service = SummaryNodeService()
async def rag_config(state):
    """Build the knowledge-base retrieval config for the user's RAG memory.

    Args:
        state: Graph state dict; only 'user_rag_memory_id' is read.

    Returns:
        dict: Retrieval configuration with a single knowledge-base entry
        plus merge/reranker settings (reranker id comes from the
        'reranker_id' environment variable, None when unset).
    """
    memory_kb_id = state.get('user_rag_memory_id', '')
    kb_entry = {
        "kb_id": memory_kb_id,
        "similarity_threshold": 0.7,
        "vector_similarity_weight": 0.5,
        "top_k": 10,
        "retrieve_type": "participle"
    }
    return {
        "knowledge_bases": [kb_entry],
        "merge_strategy": "weight",
        "reranker_id": os.getenv('reranker_id'),
        "reranker_top_k": 10
    }
async def rag_knowledge(state, question):
    """Retrieve knowledge-base chunks for a question from the user's RAG memory.

    Best effort: if processing the retrieval result fails, empty results are
    returned instead of raising.  Fix over the previous version: the bare
    ``except Exception:`` silently discarded the exception — it is now bound
    and included in the log line so failures are diagnosable.

    Args:
        state: Graph state dict; reads 'end_user_id' and 'user_rag_memory_id'.
        question: Query text.

    Returns:
        tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
        where retrieval_knowledge is a list of chunk texts, clean_content and
        raw_results are the chunks joined with blank lines, and cleaned_query
        is the original question.
    """
    kb_config = await rag_config(state)
    end_user_id = state.get('end_user_id', '')
    user_rag_memory_id = state.get("user_rag_memory_id", '')
    retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
    cleaned_query = question
    try:
        retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
        clean_content = '\n\n'.join(retrieval_knowledge)
        raw_results = clean_content
        logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
    except Exception as e:
        # Best-effort fallback to empty results, but keep the cause visible.
        retrieval_knowledge = []
        clean_content = ''
        raw_results = ''
        logger.info(
            f"No content retrieved from knowledge base: {user_rag_memory_id}, error={e}"
        )
    return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
@@ -71,7 +108,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
# 验证结构化响应
if structured is None:
logger.warning(f"LLM返回None使用默认回答")
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
@@ -82,7 +119,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning(f"结构化响应缺少data字段")
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
@@ -186,12 +223,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
}
try:
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
if storage_type!="rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
# 'input_summary',RetrieveSummaryResponse)
@@ -290,7 +328,6 @@ async def Summary(state: ReadState)-> ReadState:
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '')

View File

@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
config_id: Configuration ID for processing (used to load pruning config)
Returns:
List of DialogData objects with generated chunks
@@ -57,6 +57,61 @@ async def get_chunked_dialogs(
end_user_id=end_user_id,
config_id=config_id
)
# 语义剪枝步骤(在分块之前)
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 加载剪枝配置
pruning_config = None
if config_id:
try:
with get_db_context() as db:
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="semantic_pruning"
)
if memory_config:
pruning_config = PruningConfig(
pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold
)
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
# 获取LLM客户端用于剪枝
if pruning_config.pruning_switch:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
original_msg_count = len(dialog_data.context.msgs)
# 使用 prune_dataset 而不是 prune_dialog
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
pruned_dialogs = await pruner.prune_dataset([dialog_data])
if pruned_dialogs:
dialog_data = pruned_dialogs[0]
remaining_msg_count = len(dialog_data.context.msgs)
deleted_count = original_msg_count - remaining_msg_count
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
else:
logger.warning("[剪枝] prune_dataset 返回空列表")
else:
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
except Exception as e:
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
except Exception as e:
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)

View File

@@ -139,14 +139,14 @@ async def get_raw_tags_from_db(
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
查询更多的标签(40)给LLM提供更丰富的上下文进行筛选但最终返回数量由limit参数控制
Args:
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
limit: 最终返回的标签数量限制默认10
by_user: 是否按user_id查询默认False按end_user_id查询
Raises:
@@ -161,8 +161,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
# 1. 从数据库获取原始排名靠前的标签查询40条给LLM提供更丰富的上下文
query_limit = 40
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
if not raw_tags_with_freq:
return []
@@ -177,7 +178,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
if tag in meaningful_tag_names:
final_tags.append((tag, freq))
return final_tags
# 4. 限制返回的标签数量
return final_tags[:limit]
finally:
# 确保关闭连接
await connector.close()

View File

@@ -5,20 +5,27 @@
- 对话级一次性抽取判定相关性
- 仅对"不相关对话"的消息按比例删除
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
"""
import asyncio
import os
import hashlib
import json
import re
from collections import OrderedDict
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Dict, Tuple, Set
from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
SceneConfigRegistry,
ScenePatterns
)
class DialogExtractionResponse(BaseModel):
@@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel):
keywords: List[str] = Field(default_factory=list)
class MessageImportanceResponse(BaseModel):
"""消息重要性批量判断的结构化返回用于LLM语义判断
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
- reasons: 可选的判断理由
"""
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
class QAPair(BaseModel):
"""问答对模型,用于识别和保护对话中的问答结构。"""
question_idx: int = Field(..., description="问题消息的索引")
answer_idx: int = Field(..., description="答案消息的索引")
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
class SemanticPruner:
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
@@ -43,109 +67,374 @@ class SemanticPruner:
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
"""
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
cfg_dict = get_pruning_config() if config is None else config.model_dump()
self.config = PruningConfig.model_validate(cfg_dict)
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
# 如果没有提供config使用默认配置
if config is None:
# 使用默认的剪枝配置
config = PruningConfig(
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
pruning_scene="education",
pruning_threshold=0.5
)
self.config = config
self.llm_client = llm_client
self.language = language # 保存语言配置
self.max_concurrent = max_concurrent # 新增:最大并发数
# 详细日志配置:限制逐条消息日志的数量
self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载场景特定配置
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
self.config.pruning_scene,
fallback_to_generic=True
)
# 检查场景是否有专门支持
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
if is_supported:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
else:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
self._cache_max_size = 1000 # 缓存大小限制
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
# 采用顺序处理,移除并发配置以简化与稳定执行
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
改进版:使用场景特定的模式进行识别
- 根据 pruning_scene 动态加载对应的识别规则
- 支持教育、在线服务、外呼三个场景的特定模式
"""
import re
text = message.msg.strip()
if not text:
return False
patterns = [
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
r"\b\d{1,2}:\d{2}\b",
r"\d{4}\d{1,2}月\d{1,2}日",
r"上午|下午|AM|PM",
r"订单号|工单|申请号|编号|ID|账号|账户",
r"电话|手机号|微信|QQ|邮箱",
r"地址|地点",
r"金额|费用|价格|¥|¥|\d+元",
r"时间|日期|有效期|截止",
]
for p in patterns:
if re.search(p, text, flags=re.IGNORECASE):
# 使用场景特定的模式
all_patterns = (
self.scene_config.high_priority_patterns +
self.scene_config.medium_priority_patterns +
self.scene_config.low_priority_patterns
)
for pattern, _ in all_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
return True
# 检查是否为问句(以问号结尾或包含疑问词)
if text.endswith("") or text.endswith("?"):
return True
# 检查是否包含问句关键词
if any(keyword in text for keyword in self.scene_config.question_keywords):
return True
# 检查是否包含决策性关键词
if any(keyword in text for keyword in self.scene_config.decision_keywords):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
简单启发:匹配到的类别越多、越关键分值越高。
改进版使用场景特定的权重体系0-10分
- 根据场景动态调整不同信息类型的权重
- 高优先级模式4-6分
- 中优先级模式2-3分
- 低优先级模式1分
"""
import re
text = message.msg.strip()
score = 0
weights = [
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
(r"\b\d{1,2}:\d{2}\b", 2),
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
(r"电话|手机号|微信|QQ|邮箱", 3),
(r"地址|地点", 2),
(r"金额|费用|价格|¥|¥|\d+元", 4),
(r"时间|日期|有效期|截止", 2),
]
for p, w in weights:
if re.search(p, text, flags=re.IGNORECASE):
score += w
return score
# 使用场景特定的权重
for pattern, weight in self.scene_config.high_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.medium_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.low_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
# 问句加分
if text.endswith("") or text.endswith("?"):
score += 2
# 包含问句关键词加分
if any(keyword in text for keyword in self.scene_config.question_keywords):
score += 1
# 包含决策性关键词加分
if any(keyword in text for keyword in self.scene_config.decision_keywords):
score += 2
# 长度加分(较长的消息通常包含更多信息)
if len(text) > 50:
score += 1
if len(text) > 100:
score += 1
return min(score, 10) # 最高10分
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息用于跳过LLM分类以加速
"""检测典型寒暄/口头禅/确认类短消息。
改进版:更严格的填充消息判断,避免误删场景相关内容
满足以下之一视为填充消息:
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
- 纯标点或空白
- 在场景特定填充词库中(精确匹配)
- 纯表情符号
- 常见寒暄(精确匹配短语)
注意:不再使用长度判断,避免误删短但重要的消息
"""
import re
t = message.msg.strip()
if not t:
return True
# 常见填充语
fillers = [
"你好", "您好", "在吗", "", "嗯嗯", "", "好的", "", "", "可以", "不可以", "谢谢",
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", ""
]
if t in fillers:
# 检查是否在场景特定填充词库中(精确匹配)
if t in self.scene_config.filler_phrases:
return True
# 长度与字符类型判断
if len(t) <= 8:
# 非数字、无关键实体的短文本
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
# 主要是标点或简单确认词
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
return True
# 常见寒暄和问候(精确匹配,避免误删)
common_greetings = {
"在吗", "在不在", "在呢", "在的",
"你好", "您好", "hello", "hi",
"拜拜", "再见", "", "88", "bye",
"好的", "", "", "可以", "", "", "",
"是的", "", "对的", "没错", "是啊",
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
}
if t in common_greetings:
return True
# 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t):
return True
# 检查是否为纯emojiUnicode表情
emoji_pattern = re.compile(
"["
"\U0001F600-\U0001F64F" # 表情符号
"\U0001F300-\U0001F5FF" # 符号和象形文字
"\U0001F680-\U0001F6FF" # 交通和地图符号
"\U0001F1E0-\U0001F1FF" # 旗帜
"\U00002702-\U000027B0"
"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE
)
if emoji_pattern.fullmatch(t):
return True
# 纯标点符号
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True
return False
async def _batch_evaluate_importance_with_llm(
self,
messages: List[ConversationMessage],
context: str = ""
) -> Dict[int, int]:
"""使用LLM批量评估消息的重要性语义层面
Args:
messages: 消息列表
context: 对话上下文(可选)
Returns:
消息索引到重要性分数(0-10)的映射
"""
if not self.llm_client or not messages:
return {}
# 构建批量评估的提示词
msg_list = []
for idx, msg in enumerate(messages):
msg_list.append(f"{idx}. {msg.msg}")
msg_text = "\n".join(msg_list)
prompt = f"""请评估以下消息的重要性给每条消息打分0-10分
- 0-2分无意义的寒暄、口头禅、纯表情
- 3-5分一般性对话有一定信息量但不关键
- 6-8分包含重要信息时间、地点、人物、事件等
- 9-10分关键决策、承诺、重要数据
对话上下文:
{context if context else ""}
待评估的消息:
{msg_text}
请以JSON格式返回格式为
{{
"importance_scores": {{
"0": 分数,
"1": 分数,
...
}}
}}
"""
try:
messages_for_llm = [
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
{"role": "user", "content": prompt}
]
response = await self.llm_client.response_structured(
messages_for_llm,
MessageImportanceResponse
)
# 转换字符串键为整数键
return {int(k): v for k, v in response.importance_scores.items()}
except Exception as e:
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
return {}
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
"""识别对话中的问答对,用于保护问答结构的完整性。
改进版:使用场景特定的问句关键词,并排除寒暄类问句
Args:
messages: 消息列表
Returns:
问答对列表
"""
qa_pairs = []
# 寒暄类问句,不应该被保护(这些不是真正的问答)
greeting_questions = {
"在吗", "在不在", "你好吗", "怎么样", "好吗",
"有空吗", "忙吗", "睡了吗", "起床了吗"
}
for i in range(len(messages) - 1):
current_msg = messages[i].msg.strip()
next_msg = messages[i + 1].msg.strip()
# 排除寒暄类问句
if current_msg in greeting_questions:
continue
# 使用场景特定的问句关键词,但要求更严格
is_question = False
# 1. 以问号结尾
if current_msg.endswith("") or current_msg.endswith("?"):
is_question = True
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时"]):
is_question = True
if is_question and next_msg:
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
is_answer = not (next_msg.endswith("") or next_msg.endswith("?"))
# 排除寒暄类回复
greeting_answers = {"你好", "您好", "在呢", "在的", "", "", "好的"}
if next_msg in greeting_answers:
is_answer = False
if is_answer:
qa_pairs.append(QAPair(
question_idx=i,
answer_idx=i + 1,
confidence=0.8 # 基于规则的置信度
))
return qa_pairs
def _get_protected_indices(
self,
messages: List[ConversationMessage],
qa_pairs: List[QAPair],
window_size: int = 2
) -> Set[int]:
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
Args:
messages: 消息列表
qa_pairs: 问答对列表
window_size: 上下文窗口大小(前后各保留几条消息)
Returns:
需要保护的消息索引集合
"""
protected = set()
for qa_pair in qa_pairs:
# 保护问答对本身
protected.add(qa_pair.question_idx)
protected.add(qa_pair.answer_idx)
# 保护上下文窗口
for offset in range(-window_size, window_size + 1):
q_idx = qa_pair.question_idx + offset
a_idx = qa_pair.answer_idx + offset
if 0 <= q_idx < len(messages):
protected.add(q_idx)
if 0 <= a_idx < len(messages):
protected.add(a_idx)
return protected
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
- 仅使用 LLM 结构化输出;
改进版:
- LRU缓存管理
- 重试机制
- 降级策略
"""
# 缓存命中则直接返回(场景+内容作为键)
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
# LRU缓存如果命中移到末尾最近使用
if cache_key in self._dialog_extract_cache:
self._dialog_extract_cache.move_to_end(cache_key)
return self._dialog_extract_cache[cache_key]
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
# LRU缓存大小限制超过限制时删除最旧的条目
if len(self._dialog_extract_cache) >= self._cache_max_size:
# 删除最旧的条目OrderedDict的第一个
oldest_key = next(iter(self._dialog_extract_cache))
del self._dialog_extract_cache[oldest_key]
self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目")
rendered = self.template.render(
pruning_scene=self.config.pruning_scene,
dialog_text=dialog_text,
language=self.language
)
log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene,
"language": self.language
})
log_prompt_rendering("pruning-extract", rendered)
# 强制使用 LLM;移除正则回退
# 强制使用 LLM
if not self.llm_client:
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
@@ -153,12 +442,32 @@ class SemanticPruner:
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
{"role": "user", "content": rendered},
]
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
# 重试机制
max_retries = 3
for attempt in range(max_retries):
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
if attempt < max_retries - 1:
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
continue
else:
# 降级策略:标记为相关,避免误删
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
fallback_response = DialogExtractionResponse(
is_related=True,
times=[],
ids=[],
amounts=[],
contacts=[],
addresses=[],
keywords=[]
)
return fallback_response
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。"""
@@ -248,12 +557,14 @@ class SemanticPruner:
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
"""数据集层面:全局消息级剪枝,保留所有对话。
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
- 保证每段对话至少保留1条消息不会删除整段对话。
改进版:
- 消息级独立判断,每条消息根据场景规则独立评估
- 问答对保护已注释(暂不启用,留作观察)
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
- 保证每段对话至少保留1条消息不会删除整段对话
"""
# 如果剪枝功能关闭,直接返回原始数据集
# 如果剪枝功能关闭,直接返回原始数据集
if not self.config.pruning_switch:
return dialogs
@@ -264,179 +575,140 @@ class SemanticPruner:
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
)
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
evaluated_dialogs = []
for idx, dd in enumerate(dialogs):
try:
ex = await self._extract_dialog_important(dd.content)
evaluated_dialogs.append({
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
})
except Exception:
evaluated_dialogs.append({
"dialog": dd,
"is_related": True,
"index": idx,
"extraction": None
})
# 统计相关 / 不相关对话
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
self._log(
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
)
# 简洁打印第几段对话相关/不相关索引基于1
def _fmt_indices(items, cap: int = 10):
inds = [i["index"] + 1 for i in items]
if len(inds) <= cap:
return inds
# 超过上限时只打印前cap个并标注总数
return inds[:cap] + ["...", f"{len(inds)}"]
rel_inds = _fmt_indices(related_dialogs)
nrel_inds = _fmt_indices(not_related_dialogs)
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}")
result: List[DialogData] = []
if not_related_dialogs:
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM
per_dialog_info = {}
total_unrelated = 0
total_capacity = 0
for d in not_related_dialogs:
dd = d["dialog"]
extraction = d.get("extraction")
if extraction is None:
extraction = await self._extract_dialog_important(dd.content)
# 合并所有重要标记
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dd.context.msgs
# 分类消息
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
# 重要消息按重要性排序
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
info = {
"dialog": dd,
"total_msgs": len(msgs),
"unrelated_count": len(msgs),
"imp_ids_sorted": imp_sorted_ids,
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
}
per_dialog_info[d["index"]] = info
total_unrelated += info["unrelated_count"]
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
global_delete = int(total_unrelated * proportion)
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
global_delete = 1
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例)且至少保留1条消息
capacities = []
for d in not_related_dialogs:
idx = d["index"]
info = per_dialog_info[idx]
# 统计重要数量
imp_count = len(info["imp_ids_sorted"])
unimp_count = len(info["unimp_ids"])
imp_cap = int(imp_count * proportion)
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
capacities.append(cap)
total_capacity = sum(capacities)
if global_delete > total_capacity:
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
global_delete = total_capacity
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
alloc = []
for i, d in enumerate(not_related_dialogs):
idx = d["index"]
info = per_dialog_info[idx]
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
alloc.append(min(share, capacities[i]))
allocated = sum(alloc)
rem = global_delete - allocated
turn = 0
while rem > 0 and turn < 100000:
progressed = False
for i in range(len(not_related_dialogs)):
if rem <= 0:
break
if alloc[i] < capacities[i]:
alloc[i] += 1
rem -= 1
progressed = True
if not progressed:
break
turn += 1
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
total_deleted_confirm = 0
for d in evaluated_dialogs:
dd = d["dialog"]
msgs = dd.context.msgs
original = len(msgs)
if d["is_related"]:
result.append(dd)
continue
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
if idx_in_unrel is None:
result.append(dd)
continue
quota = alloc[idx_in_unrel]
info = per_dialog_info[d["index"]]
# 计算本对话重要最多可删数量
imp_count = len(info["imp_ids_sorted"])
imp_del_cap = int(imp_count * proportion)
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
del_unimp = min(quota, len(unimp_delete_ids))
rem_quota = quota - del_unimp
# 再从重要里选低分优先的删除ID不超过 imp_del_cap
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
deleted_here = 0
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept = []
for m in msgs:
mid = id(m)
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
deleted_here += 1
continue
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
actual_imp_deleted += 1
deleted_here += 1
continue
kept.append(m)
if not kept and msgs:
kept = [msgs[0]]
dd.context.msgs = kept
total_deleted_confirm += deleted_here
self._log(
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
else:
# 全部相关:不执行剪枝
result = [d["dialog"] for d in evaluated_dialogs]
total_original_msgs = 0
total_deleted_msgs = 0
for d_idx, dd in enumerate(dialogs):
msgs = dd.context.msgs
original_count = len(msgs)
total_original_msgs += original_count
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
# qa_pairs = self._identify_qa_pairs(msgs)
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
# ========================================================
# 消息级分类:每条消息独立判断
important_msgs = [] # 重要消息(保留)
unimportant_msgs = [] # 不重要消息(可删除)
filler_msgs = [] # 填充消息(优先删除)
# 判断是否需要详细日志仅对前N条消息记录
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
# ========== 问答对保护判断(已注释) ==========
# if idx in protected_indices:
# important_msgs.append((idx, m))
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
# ==========================================
# 填充消息(寒暄、表情等)
if self._is_filler_message(m):
filler_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
# 重要信息(学号、成绩、时间、金额等)
elif self._is_important_message(m):
important_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
# 其他消息
else:
unimportant_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
# 计算删除配额
delete_target = int(original_count * proportion)
if proportion > 0 and original_count > 0 and delete_target == 0:
delete_target = 1
# 确保至少保留1条消息
max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable)
# 删除策略:优先删除填充消息,再删除不重要消息
to_delete_indices = set()
deleted_details = [] # 记录删除的消息详情
# 第一步:删除填充消息
filler_to_delete = min(len(filler_msgs), delete_target)
for i in range(filler_to_delete):
idx, msg = filler_msgs[i]
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
# 第二步:如果还需要删除,删除不重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0:
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
for i in range(unimp_to_delete):
idx, msg = unimportant_msgs[i]
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
# 第三步:如果还需要删除,按重要性分数删除重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0 and important_msgs:
# 按重要性分数排序(分数低的优先删除)
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
imp_to_delete = min(len(imp_sorted), remaining_quota)
for i in range(imp_to_delete):
idx, msg = imp_sorted[i]
to_delete_indices.add(idx)
score = self._importance_score(msg)
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
# 执行删除
kept_msgs = []
for idx, m in enumerate(msgs):
if idx not in to_delete_indices:
kept_msgs.append(m)
# 确保至少保留1条
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
dd.context.msgs = kept_msgs
deleted_count = original_count - len(kept_msgs)
total_deleted_msgs += deleted_count
# 输出删除详情
if deleted_details:
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
for detail in deleted_details:
self._log(f" {detail}")
# ========== 问答对统计(已注释) ==========
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
# ========================================
self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
# 保存日志
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f:
@@ -448,6 +720,7 @@ class SemanticPruner:
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
def _log(self, msg: str) -> None:

View File

@@ -0,0 +1,326 @@
"""
场景特定配置 - 为不同场景提供定制化的剪枝规则
功能:
- 场景特定的重要信息识别模式
- 场景特定的重要性评分权重
- 场景特定的填充词库
- 场景特定的问答对识别规则
"""
from typing import Dict, List, Set, Tuple
from dataclasses import dataclass, field
@dataclass
class ScenePatterns:
    """Scene-specific recognition patterns used by the semantic pruner.

    Each field drives one aspect of message classification during pruning:
    the regex pattern lists score "important" messages, while the phrase
    and keyword sets detect filler chit-chat, questions, and decision or
    commitment statements. All fields default to empty containers, so a
    bare ScenePatterns() applies no scene-specific rules at all.
    """
    # Regex patterns that mark important information, as (pattern, weight)
    # tuples, split into priority tiers from highest to lowest.
    high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)  # (pattern, weight)
    medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
    low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
    # Filler phrase lexicon (meaningless chit-chat; matched exactly)
    filler_phrases: Set[str] = field(default_factory=set)
    # Question keywords (used to identify question/answer pairs)
    question_keywords: Set[str] = field(default_factory=set)
    # Decision / commitment keywords
    decision_keywords: Set[str] = field(default_factory=set)
class SceneConfigRegistry:
    """Scene configuration registry — manages the per-scene pruning rules.

    Exposes one factory classmethod per predefined scene (education,
    online_service, outbound), plus a generic fallback configuration.
    The set of supported scenes is defined once in ``_scene_builders``;
    ``get_config``, ``get_all_scenes`` and ``is_scene_supported`` all
    derive from it, so adding a scene requires touching only that map.
    """

    # Base generic patterns shared by every scene.
    # High priority: identifiers, money amounts, phone numbers, emails.
    BASE_HIGH_PRIORITY = [
        (r"订单号|工单|申请号|编号|ID|账号|账户", 5),
        (r"金额|费用|价格|¥|¥|\d+元", 5),
        (r"\d{11}", 4),  # 11-digit mobile phone number
        (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4),  # email address
    ]
    # Medium priority: dates, contact info, locations, relative time words.
    BASE_MEDIUM_PRIORITY = [
        (r"\d{4}-\d{1,2}-\d{1,2}", 3),  # ISO-like date
        (r"\d{4}\d{1,2}月\d{1,2}日", 3),
        (r"电话|手机号|微信|QQ|联系方式", 3),
        (r"地址|地点|位置", 2),
        (r"时间|日期|有效期|截止", 2),
        (r"今天|明天|后天|昨天|前天", 3),  # relative days (weight raised)
        (r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
        (r"今年|去年|明年", 3),
    ]
    # Low priority: times of day and clock times.
    BASE_LOW_PRIORITY = [
        (r"\d{1,2}:\d{2}", 2),  # clock time HH:MM
        (r"\d{1,2}点\d{0,2}分?", 2),  # clock time "X点Y分" or "X点"
        (r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2),  # time-of-day words
        (r"AM|PM|am|pm", 1),
    ]
    # Generic filler phrases: greetings, verbal tics, acknowledgements,
    # punctuation runs, bracketed emoticons, and internet slang.
    BASE_FILLERS = {
        # basic greetings
        "你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
        "好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
        "拜拜", "再见", "88", "", "回见",
        # verbal tics
        "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
        "", "", "", "", "", "", "嗯哼",
        # acknowledgements
        "是的", "", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
        # punctuation and symbols
        "。。。", "...", "???", "", "!!!", "",
        # bracketed emoticons
        "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
        "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
        "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
        "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
        # internet slang
        "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
        "emmm", "emm", "em", "mmp", "wtf", "omg",
    }
    # Generic question keywords (interrogative words).
    BASE_QUESTION_KEYWORDS = {
        "什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时", ""
    }
    # Generic decision / commitment keywords.
    BASE_DECISION_KEYWORDS = {
        "必须", "一定", "务必", "需要", "要求", "规定", "应该",
        "承诺", "保证", "确保", "负责", "同意", "答应"
    }

    @classmethod
    def _scene_builders(cls):
        """Single source of truth: scene name -> config factory.

        Every public method that needs the supported-scene set derives
        it from this map, so the three views can never desynchronize.
        Insertion order defines the order reported by get_all_scenes().
        """
        return {
            'education': cls.get_education_config,
            'online_service': cls.get_online_service_config,
            'outbound': cls.get_outbound_config,
        }

    @classmethod
    def get_education_config(cls) -> "ScenePatterns":
        """Education scene configuration (grades, courses, study content)."""
        return ScenePatterns(
            high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
                # grades and scores (highest priority)
                (r"成绩|分数|得分|满分|及格|不及格", 6),
                (r"GPA|绩点|学分|平均分", 6),
                (r"\d+分|\d+\.?\d*分", 5),  # concrete scores
                (r"排名|名次|第.{1,3}名", 5),  # ranks like "第三名", "第1名"
                # student/teacher registration info
                (r"学号|学生证|教师工号|工号", 5),
                (r"班级|年级|专业|院系", 4),
                # courses and textbooks
                (r"课程|科目|学科|必修|选修", 4),
                (r"教材|课本|教科书|参考书", 4),
                (r"章节|第.{1,3}章|第.{1,3}节", 3),  # chapters/sections
                # subject content
                (r"微积分|导数|积分|函数|极限|微分", 4),
                (r"代数|几何|三角|概率|统计", 4),
                (r"物理|化学|生物|历史|地理", 4),
                (r"英语|语文|数学|政治|哲学", 4),
                (r"定义|定理|公式|概念|原理|法则", 3),
                (r"例题|解题|证明|推导|计算", 3),
            ],
            medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
                # teaching activities
                (r"作业|练习|习题|题目", 3),
                (r"考试|测验|测试|考核|期中|期末", 3),
                (r"上课|下课|课堂|讲课", 2),
                (r"提问|回答|发言|讨论", 2),
                (r"问一下|请教|咨询|询问", 2),  # asking / consulting
                (r"理解|明白|懂|掌握|学会", 2),  # learning-state words
                # schedules
                (r"课表|课程表|时间表", 3),
                (r"第.{1,3}节课|第.{1,3}周", 2),  # lesson/week ordinals
            ],
            low_priority_patterns=cls.BASE_LOW_PRIORITY + [
                (r"老师|教师|同学|学生", 1),
                (r"教室|实验室|图书馆", 1),
            ],
            filler_phrases=cls.BASE_FILLERS | {
                # education-specific fillers (deliberately excludes words
                # like "明白了"/"懂了" that carry meaning in this scene)
                "老师好", "同学们好", "上课", "下课", "起立", "坐下",
                "举手", "请坐", "很好", "不错", "继续",
                "下一个", "下一题", "下一位", "还有吗", "还有问题吗",
            },
            question_keywords=cls.BASE_QUESTION_KEYWORDS | {
                "为啥", "", "咋办", "怎样", "如何做",
                "能不能", "可不可以", "行不行", "对不对", "是不是",
            },
            decision_keywords=cls.BASE_DECISION_KEYWORDS | {
                "必考", "重点", "考点", "难点", "关键",
                "记住", "背诵", "掌握", "理解", "复习",
            }
        )

    @classmethod
    def get_online_service_config(cls) -> "ScenePatterns":
        """Online customer-service scene configuration (tickets, products)."""
        return ScenePatterns(
            high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
                # tickets (highest priority)
                (r"工单号|工单编号|ticket|TK\d+", 6),
                (r"工单状态|处理中|已解决|已关闭|待处理", 5),
                (r"优先级|紧急|高优先级|P0|P1|P2", 5),
                # product identifiers
                (r"产品型号|型号|SKU|产品编号", 5),
                (r"序列号|SN|设备号", 5),
                (r"版本号|软件版本|固件版本", 4),
                # problem descriptions
                (r"故障|错误|异常|bug|问题", 4),
                (r"错误代码|故障代码|error code", 5),
                (r"无法|不能|失败|报错", 3),
            ],
            medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
                # after-sales service
                (r"退款|退货|换货|补发", 4),
                (r"发票|收据|凭证", 3),
                (r"物流|快递|运单号", 3),
                (r"保修|质保|售后", 3),
                # service-level timing
                (r"SLA|响应时间|处理时长", 4),
                (r"超时|延迟|等待", 2),
            ],
            low_priority_patterns=cls.BASE_LOW_PRIORITY + [
                (r"客服|工程师|技术支持", 1),
                (r"用户|客户|会员", 1),
            ],
            filler_phrases=cls.BASE_FILLERS | {
                # service-specific courtesy fillers
                "您好", "请问", "请稍等", "稍等", "马上", "立即",
                "正在查询", "正在处理", "正在为您", "帮您查一下",
                "还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
                "感谢您的耐心等待", "抱歉让您久等了",
                "已记录", "已反馈", "已转接", "已升级",
                "祝您生活愉快", "再见", "欢迎下次咨询",
            },
            question_keywords=cls.BASE_QUESTION_KEYWORDS | {
                "能否", "可否", "是否", "有没有", "能不能",
                "怎么办", "如何处理", "怎么解决",
            },
            decision_keywords=cls.BASE_DECISION_KEYWORDS | {
                "立即处理", "马上解决", "尽快", "优先",
                "升级", "转接", "派单", "跟进",
                "补偿", "赔偿", "退款", "换货",
            }
        )

    @classmethod
    def get_outbound_config(cls) -> "ScenePatterns":
        """Outbound-call scene configuration (intent, follow-ups, deals)."""
        return ScenePatterns(
            high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
                # customer intent (highest priority)
                (r"意向|意愿|兴趣|感兴趣", 6),
                (r"A类|B类|C类|D类|高意向|低意向", 6),
                (r"成交|签约|下单|购买|确认", 6),
                # contact scheduling (more important for outbound calls)
                (r"预约|约定|安排|确定时间", 5),
                (r"下次联系|回访|跟进", 5),
                (r"方便|有空|可以|时间", 4),
                # call state
                (r"接通|未接通|占线|关机|停机", 4),
                (r"通话时长|通话时间", 3),
            ],
            medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
                # customer profile
                (r"姓名|称呼|先生|女士", 3),
                (r"公司|单位|职位|职务", 3),
                (r"需求|要求|期望", 3),
                # follow-up state
                (r"跟进状态|进展|进度", 3),
                (r"已联系|待联系|联系中", 2),
                (r"拒绝|不感兴趣|考虑|再说", 3),
            ],
            low_priority_patterns=cls.BASE_LOW_PRIORITY + [
                (r"销售|客户经理|业务员", 1),
                (r"产品|服务|方案", 1),
            ],
            filler_phrases=cls.BASE_FILLERS | {
                # outbound-specific fillers
                "您好", "", "hello", "打扰了", "不好意思",
                "方便接电话吗", "现在方便吗", "占用您一点时间",
                "我是", "我们是", "我们公司", "我们这边",
                "了解一下", "介绍一下", "简单说一下",
                "考虑考虑", "想一想", "再说", "再看看",
                "不需要", "不感兴趣", "没兴趣", "不用了",
                "好的", "", "可以", "没问题", "那就这样",
                "再联系", "回头聊", "有需要再说",
            },
            question_keywords=cls.BASE_QUESTION_KEYWORDS | {
                "有没有", "需不需要", "要不要", "考虑不考虑",
                "了解吗", "知道吗", "听说过吗",
                "方便吗", "有空吗", "在吗",
            },
            decision_keywords=cls.BASE_DECISION_KEYWORDS | {
                "确定", "决定", "选择", "购买", "下单",
                "预约", "安排", "约定", "确认",
                "跟进", "回访", "联系", "沟通",
            }
        )

    @classmethod
    def get_config(cls, scene: str, fallback_to_generic: bool = True) -> "ScenePatterns":
        """Return the configuration for a scene name.

        Args:
            scene: scene name ('education', 'online_service', 'outbound', or other)
            fallback_to_generic: whether to fall back to the generic config
                when the scene has no dedicated configuration

        Returns:
            The scene's configuration. For an unknown scene:
            - fallback_to_generic=True: the generic (base-rules-only) config
            - fallback_to_generic=False: raises ValueError

        Raises:
            ValueError: unknown scene and fallback_to_generic is False.
        """
        builders = cls._scene_builders()
        if scene in builders:
            return builders[scene]()
        if fallback_to_generic:
            # conservative generic config: base rules only
            return cls.get_generic_config()
        raise ValueError(f"不支持的场景: {scene},支持的场景: {list(builders.keys())}")

    @classmethod
    def get_generic_config(cls) -> "ScenePatterns":
        """Generic configuration for undefined scenes.

        Conservative: only the shared base rules, no scene-specific
        patterns, to avoid deleting important information by mistake.
        """
        return ScenePatterns(
            high_priority_patterns=cls.BASE_HIGH_PRIORITY,
            medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
            low_priority_patterns=cls.BASE_LOW_PRIORITY,
            filler_phrases=cls.BASE_FILLERS,
            question_keywords=cls.BASE_QUESTION_KEYWORDS,
            decision_keywords=cls.BASE_DECISION_KEYWORDS
        )

    @classmethod
    def get_all_scenes(cls) -> List[str]:
        """List all predefined scene names (derived from _scene_builders)."""
        return list(cls._scene_builders().keys())

    @classmethod
    def is_scene_supported(cls, scene: str) -> bool:
        """Whether the scene has a dedicated configuration.

        Returns:
            True: a dedicated configuration exists.
            False: the generic configuration will be used instead.
        """
        return scene in cls._scene_builders()

View File

@@ -1932,17 +1932,17 @@ def preprocess_data(
Returns:
经过清洗转换后的 DialogData 列表
"""
print("\n=== 数据预处理 ===")
logger.debug("=== 数据预处理 ===")
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
preprocessor = DataPreprocessor()
try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data
except Exception as e:
print(f"数据预处理过程中出现错误: {e}")
logger.error(f"数据预处理过程中出现错误: {e}")
raise
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
Returns:
带 chunks 的 DialogData 列表
"""
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
if not data:
raise ValueError("预处理数据为空,无法进行分块")
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None,
skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None,
) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: 输入数据路径
llm_client: LLM 客户端
skip_cleaning: 是否跳过数据清洗步骤默认False
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
Returns:
带 chunks 的 DialogData 列表
"""
print("\n=== 完整数据处理流程(包含预处理)===")
logger.debug("=== 完整数据处理流程(包含预处理)===")
if input_data_path is None:
input_data_path = os.path.join(
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
pruner = SemanticPruner(llm_client=llm_client)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
if pruning_config:
# 使用传入的配置
config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else:
# 使用默认配置(关闭剪枝)
config = None
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
pruner = SemanticPruner(config=config, llm_client=llm_client)
# 记录单对话场景下剪枝前的消息数量
single_dialog_original_msgs = None
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
print(
logger.debug(
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}"
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
)
else:
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
# 保存剪枝后的数据
try:
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
dp = DataPreprocessor(output_file_path=pruned_output_path)
dp.save_data(preprocessed_data, output_path=pruned_output_path)
except Exception as se:
print(f"保存剪枝结果失败:{se}")
logger.error(f"保存剪枝结果失败:{se}")
except Exception as e:
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
# 步骤3: 对话分块
return await get_chunked_dialogs_from_preprocessed(

View File

@@ -1,5 +1,7 @@
import os
from typing import Optional
from typing import Optional, List, Any
from enum import Enum
from pathlib import Path
from app.core.logging_config import get_memory_logger
from app.core.memory.models.message_models import DialogData, Chunk
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
logger = get_memory_logger(__name__)
class ChunkerStrategy(Enum):
"""Supported chunking strategies."""
RECURSIVE = "RecursiveChunker"
SEMANTIC = "SemanticChunker"
LATE = "LateChunker"
NEURAL = "NeuralChunker"
LLM = "LLMChunker"
@classmethod
def get_valid_strategies(cls) -> List[str]:
"""Get list of valid strategy names."""
return [strategy.value for strategy in cls]
class DialogueChunker:
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
@@ -17,23 +33,51 @@ class DialogueChunker:
of different chunking strategies to dialogue data.
"""
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
"""Initialize the DialogueChunker with a specific chunking strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
llm_client: LLM client instance (required for LLMChunker strategy)
Raises:
ValueError: If chunker_strategy is invalid or required parameters are missing
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Validate strategy
valid_strategies = ChunkerStrategy.get_valid_strategies()
if chunker_strategy not in valid_strategies:
raise ValueError(
f"Invalid chunker_strategy: '{chunker_strategy}'. "
f"Must be one of {valid_strategies}"
)
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
self.chunker_strategy = chunker_strategy
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
try:
# Load and validate configuration
chunker_config_dict = get_chunker_config(chunker_strategy)
if not chunker_config_dict:
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Initialize chunker client
if self.chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("llm_client is required for LLMChunker strategy")
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
except Exception as e:
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
raise
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
"""Process a dialogue by generating chunks and adding them to the DialogData object.
Args:
@@ -43,54 +87,125 @@ class DialogueChunker:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
ValueError: If dialogue is invalid or chunking fails
Exception: If chunking process encounters an error
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# Validate input
if not dialogue:
raise ValueError("dialogue cannot be None")
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
f"Context: {dialogue.context is not None}, "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
)
logger.info(
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
f"using strategy: {self.chunker_strategy}"
)
try:
# Generate chunks
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
return chunks
# Validate results
if not chunks or len(chunks) == 0:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs)}, "
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
logger.info(
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
)
return chunks
except ValueError:
# Re-raise validation errors
raise
except Exception as e:
logger.error(
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
exc_info=True
)
raise
def save_chunking_results(
self,
chunks: List[Chunk],
dialogue: DialogData,
output_path: Optional[str] = None,
preview_length: int = 100
) -> str:
"""Save the chunking results to a file and return the output path.
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output
chunks: List of Chunk objects to save
dialogue: The DialogData object that was processed
output_path: Optional path to save the output (defaults to current directory)
preview_length: Maximum length of content preview (default: 100)
Returns:
The path where the output was saved
Raises:
ValueError: If chunks or dialogue is invalid
IOError: If file writing fails
"""
if not output_path:
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
# Validate input
if not chunks:
raise ValueError("chunks list cannot be empty")
if not dialogue:
raise ValueError("dialogue cannot be None")
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
# Generate default output path if not provided
if not output_path:
output_dir = Path(__file__).parent.parent.parent
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
logger.info(f"Saving chunking results to: {output_path}")
try:
# Prepare output content
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
f"Generated {len(chunks)} chunks:",
""
]
for i, chunk in enumerate(chunks, 1):
content_preview = chunk.content[:preview_length] if chunk.content else ""
if len(chunk.content) > preview_length:
content_preview += "..."
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {content_preview}")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
output_lines.append("")
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
# Write to file
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
logger.info(f"Chunking results saved to: {output_path}")
return output_path
logger.info(f"Successfully saved chunking results to: {output_path}")
return output_path
except IOError as e:
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
raise

View File

@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
user_id: str,
entities: str,
statements: str,
language: str = "zh"
language: str = "zh",
user_display_name: str = None
) -> str:
"""
Renders the user summary prompt using the user_summary.jinja2 template.
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
entities: Core entities with frequency information
statements: Representative statement samples
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
Returns:
Rendered prompt content as string
"""
# 如果没有提供 user_display_name使用默认值
if user_display_name is None:
user_display_name = "该用户" if language == "zh" else "the user"
template = prompt_env.get_template("user_summary.jinja2")
rendered_prompt = template.render(
user_id=user_id,
entities=entities,
statements=statements,
language=language
language=language,
user_display_name=user_display_name
)
# 记录渲染结果到提示日志
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
'user_id': user_id,
'entities_len': len(entities),
'statements_len': len(statements),
'language': language
'language': language,
'user_display_name': user_display_name
})
return rendered_prompt

View File

@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% endif %}
===Inputs===
{% if user_id %}
- User ID: {{ user_id }}
{% if user_display_name %}
- User Display Name: {{ user_display_name }}
{% endif %}
{% if entities %}
- Core Entities & Frequency: {{ entities }}
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
3. Avoid excessive adjectives and empty phrases
4. Strictly follow the output format specified below
{% if language == "zh" %}
**【严格人称规定】**
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
- 绝对禁止使用用户ID如 {{ user_id }})来称呼用户
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA
{% else %}
**【STRICT PRONOUN RULES】**
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
{% endif %}
**Section-Specific Requirements:**
{% if language == "zh" %}
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% if language == "zh" %}
Example Input:
- User ID: user_12345
- User Display Name: 张三
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
Example Output:
【基本介绍】
我是张三一名充满热情的高级产品经理。在过去的5年里专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
张三一名充满热情的高级产品经理,在深圳工作。在过去的5年里张三专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
【性格特点】
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
@@ -121,13 +135,13 @@ Example Output:
"让每一个产品决策都充满温度。"
{% else %}
Example Input:
- User ID: user_12345
- User Display Name: John
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
Example Output:
【Basic Introduction】
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
【Personality Traits】
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.

View File

@@ -68,6 +68,8 @@ class BasePlatformAdapter(ABC):
self.branch_node_cache = defaultdict(list)
self.error_branch_node_cache = []
self.node_output_map = {}
@abstractmethod
def get_metadata(self) -> PlatformMetadata:
"""get platform metadata"""

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModel
ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
@@ -44,6 +44,7 @@ class DifyConverter(BaseConverter):
warnings: list
branch_node_cache: dict
error_branch_node_cache: list
node_output_map: dict
def __init__(self):
self.CONFIG_CONVERT_MAP = {
@@ -60,33 +61,53 @@ class DifyConverter(BaseConverter):
"knowledge-retrieval": self.convert_knowledge_node_config,
"parameter-extractor": self.convert_parameter_extractor_node_config,
"question-classifier": self.convert_question_classifier_node_config,
"variable-aggregator": self.convert_variable_aggregator,
"variable-aggregator": self.convert_variable_aggregator_node_config,
"tool": self.convert_tool_node_config,
"loop-start": lambda x: {},
"iteration-start": lambda x: {},
"loop-end": lambda x: {},
}
def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
return func
def config_validate(
self,
node_id: str,
node_name: str,
config: type[BaseNodeConfig],
value: dict
):
try:
return config.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
@staticmethod
def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression))
@staticmethod
def process_var_selector(var_selector):
def process_var_selector(self, var_selector):
if not var_selector:
return ""
selector = var_selector.split('.')
if len(selector) != 2:
if len(selector) not in [2, 3]:
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
var_selector = ".".join(selector)
mapping = {
"sys.query": "sys.message"
}
"sys.query": "sys.message"
} | self.node_output_map
var_selector = mapping.get(var_selector, var_selector)
return var_selector
@@ -124,6 +145,8 @@ class DifyConverter(BaseConverter):
"checkbox": VariableType.BOOLEAN,
"file-list": VariableType.ARRAY_FILE,
"select": VariableType.STRING,
"integer": VariableType.NUMBER,
"float": VariableType.NUMBER,
}
var_type = type_map.get(source_type, source_type)
return var_type
@@ -160,6 +183,8 @@ class DifyConverter(BaseConverter):
"": ComparisonOperator.GE,
"": ComparisonOperator.LE,
"not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH,
}
return operator_map.get(operator, operator)
@@ -232,7 +257,7 @@ class DifyConverter(BaseConverter):
node_id=node["id"],
node_name=node_data["title"],
name=var["variable"],
detail=f"Unsupport Variable type for start node: {var_type}"
detail=f"Unsupported Variable type for start node: {var_type}"
)
)
continue
@@ -248,9 +273,11 @@ class DifyConverter(BaseConverter):
max_length=var.get("max_length"),
)
start_vars.append(var_def)
return StartNodeConfig(
result = StartNodeConfig.model_construct(
variables=start_vars
).model_dump()
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
return result
def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -265,16 +292,18 @@ class DifyConverter(BaseConverter):
for category in node_data["classes"]:
self.branch_node_cache[node["id"]].append(category["id"])
categories.append(
ClassifierConfig(
ClassifierConfig.model_construct(
class_name=category["name"],
)
)
return QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
return result
def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -310,7 +339,7 @@ class DifyConverter(BaseConverter):
vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"]
) if vision else None
return LLMNodeConfig.model_construct(
result = LLMNodeConfig.model_construct(
model_id=None,
context=context,
memory=memory,
@@ -318,12 +347,16 @@ class DifyConverter(BaseConverter):
vision_input=vision_input,
messages=messages
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
return result
def convert_end_node_config(self, node: dict) -> dict:
node_data = node["data"]
return EndNodeConfig(
output=self.trans_variable_format(node_data["answer"]),
result = EndNodeConfig.model_construct(
output=self.trans_variable_format(node_data.get("answer", "")),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
return result
def convert_if_else_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -354,9 +387,11 @@ class DifyConverter(BaseConverter):
)
)
self.branch_node_cache[node["id"]].append(case_id)
return IfElseNodeConfig(
result = IfElseNodeConfig.model_construct(
cases=cases
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
return result
def convert_loop_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -365,7 +400,7 @@ class DifyConverter(BaseConverter):
for condition in node_data["break_conditions"]:
right_value = condition["value"]
conditions.append(
LoopConditionDetail(
LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format(
@@ -378,7 +413,7 @@ class DifyConverter(BaseConverter):
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
)
)
condition_config = ConditionsConfig(
condition_config = ConditionsConfig.model_construct(
logical_operator=logical_operator,
expressions=conditions
)
@@ -387,9 +422,9 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable["value"])
right_value = self._process_list_variable_litearl(variable.get("value", ""))
else:
right_value = self.convert_variable_type(right_value_type, variable["value"])
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append(
CycleVariable(
name=variable["label"],
@@ -398,23 +433,28 @@ class DifyConverter(BaseConverter):
input_type=right_input_type
)
)
return LoopNodeConfig(
result = LoopNodeConfig.model_construct(
condition=condition_config,
cycle_vars=loop_variables,
max_loop=node_data["loop_count"]
max_loop=node_data.get("loop_count", 10)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
return result
def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"]
return IterationNodeConfig(
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data["output_type"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
return result
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
@@ -430,16 +470,18 @@ class DifyConverter(BaseConverter):
operation=self.convert_assignment_operator(assignment["operation"])
)
)
return AssignerNodeConfig(
result = AssignerNodeConfig.model_construct(
assignments=assignments
).model_dump()
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
return result
def convert_code_node_config(self, node: dict) -> dict:
node_data = node["data"]
input_variables = []
for input_variable in node_data["variables"]:
input_variables.append(
InputVariable(
InputVariable.model_construct(
name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
)
@@ -448,7 +490,7 @@ class DifyConverter(BaseConverter):
output_variables = []
for output_variable in node_data["outputs"]:
output_variables.append(
OutputVariable(
OutputVariable.model_construct(
name=output_variable,
type=node_data["outputs"][output_variable]["type"],
)
@@ -456,18 +498,20 @@ class DifyConverter(BaseConverter):
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
return CodeNodeConfig(
result = CodeNodeConfig.model_construct(
input_variables=input_variables,
language=node_data["code_language"],
output_variables=output_variables,
code=code
).model_dump()
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
return result
def convert_http_node_config(self, node: dict) -> dict:
node_data = node["data"]
if node_data["authorization"] != 'no-auth':
if node_data["authorization"]["type"] != 'no-auth':
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
auth_config = HttpAuthConfig(
auth_config = HttpAuthConfig.model_construct(
auth_type=auth_type,
header=node_data["authorization"]["config"].get("header"),
api_key=node_data["authorization"]["config"].get("api_key"),
@@ -499,7 +543,7 @@ class DifyConverter(BaseConverter):
body_content = ""
headers = {}
for header in node_data["headers"].split("\n"):
for header in node_data.get("headers", "").split("\n"):
if not header:
continue
@@ -517,7 +561,7 @@ class DifyConverter(BaseConverter):
))
params = {}
for param in node_data["params"].split("\n"):
for param in node_data.get("params", "").split("\n"):
if not param:
continue
@@ -542,7 +586,7 @@ class DifyConverter(BaseConverter):
default_body = ""
default_header = {}
default_status_code = 0
for var in node_data["default_value"]:
for var in node_data.get("default_value") or []:
if var["key"] == "body":
default_body = var["value"]
elif var["key"] == "header":
@@ -556,45 +600,50 @@ class DifyConverter(BaseConverter):
)
self.error_branch_node_cache.append(node['id'])
return HttpRequestNodeConfig(
result = HttpRequestNodeConfig.model_construct(
method=node_data["method"].upper(),
url=node_data["url"],
auth=auth_config,
body=HttpContentTypeConfig(
body=HttpContentTypeConfig.model_construct(
content_type=self.convert_http_content_type(node_data["body"]["type"]),
data=body_content,
),
headers=headers,
params=params,
verify_ssl=node_data["ssl_verify"],
timeouts=HttpTimeOutConfig(
timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
),
retry=HttpRetryConfig(
retry=HttpRetryConfig.model_construct(
enable=node_data["retry_config"]["retry_enabled"],
max_attempts=node_data["retry_config"]["max_retries"],
retry_interval=node_data["retry_config"]["retry_interval"],
),
error_handle=HttpErrorHandleConfig(
error_handle=HttpErrorHandleConfig.model_construct(
method=error_handle_type,
default=default_value,
)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
return result
def convert_jinja_render_node_config(self, node: dict) -> dict:
node_data = node["data"]
mapping = []
for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig(
mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"])
))
return JinjaRenderNodeConfig(
result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"],
mapping=mapping,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
return result
def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -604,10 +653,13 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.",
))
return KnowledgeRetrievalNodeConfig.model_construct(
result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
return result
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
@@ -618,42 +670,59 @@ class DifyConverter(BaseConverter):
)
)
params = []
for param in node_data["parameters"]:
for param in node_data.get("parameters", []):
params.append(
ParamsConfig(
ParamsConfig.model_construct(
name=param["name"],
desc=param["description"],
required=param["required"],
type=param["type"],
)
)
return ParameterExtractorNodeConfig.model_construct(
result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]),
params=params,
prompt=node_data["instruction"]
prompt=node_data.get("instruction")
).model_dump()
def convert_variable_aggregator(self, node: dict) -> dict:
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
return result
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"]
advanced_settings = node_data.get("advanced_settings", {})
group_variables = {}
group_type = {}
if not group_enable:
if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables["output"] = [
self._process_list_variable_litearl(variable)
for variable in node_data["variables"]
]
group_type["output"] = node_data["output_type"]
else:
for group in node_data["advanced_settings"]["groups"]:
for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable)
for variable in group["variables"]
]
group_type[group["group_name"]] = group["output_type"]
return VariableAggregatorNodeConfig(
group=group_enable,
result = VariableAggregatorNodeConfig.model_construct(
group=advanced_settings.get("group_enabled", False),
group_variables=group_variables,
group_type=group_type,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
return result
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.",
))
return {}

View File

@@ -43,7 +43,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL
}
def __init__(self, config: dict[str, Any]):
@@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)
def map_node_type(self, platform_node_type) -> str:
return self.NODE_TYPE_MAPPING.get(platform_node_type)
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property
def origin_nodes(self):
@@ -89,6 +90,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True
def parse_workflow(self) -> WorkflowParserResult:
self._init_node_output_map()
for node in self.origin_nodes:
node = self._convert_node(node)
if node:
@@ -128,6 +130,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
errors=self.errors
)
def _init_node_output_map(self):
for node in self.origin_nodes:
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
def _convert_cycle_node_position(self, node_id: str, position: dict):
for node in self.origin_nodes:
if node["id"] == node_id:
@@ -172,8 +179,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_type = node_data["type"]
try:
converter = self.get_node_convert(node_type)
if converter is None:
raise Exception(f"node type not supported - {node_type}")
if node_type not in self.CONFIG_CONVERT_MAP:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"node type {node_type} is unsupported",
))
return converter(node)
except Exception as e:
self.errors.append(ExceptionDefineition(
@@ -214,6 +226,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}",
))
logger.debug(f"convert edge error - {e}", exc_info=True)
return None
def _convert_variable(self, variable) -> VariableDefinition | None:
@@ -221,7 +234,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return VariableDefinition(
name=variable["name"],
default=variable["value"],
type=variable["value_type"],
type=self.variable_type_map(variable["value_type"]),
)
except Exception as e:
self.errors.append(ExceptionDefineition(

View File

@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
instance:
The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]).
NumberVariable, ArrayVariable[StringVariable]).
mut:
Whether the variable is mutable.
"""
@@ -152,6 +152,36 @@ class VariablePool:
return None
return var_instance
def get_instance(
    self,
    selector: str,
    default: Any = None,
    strict: bool = True
):
    """Look up a variable in the pool and return its underlying instance.

    Args:
        selector:
            Variable selector written as a string variable literal
            (e.g. "{{ sys.message }}").
        default:
            Fallback returned when the variable is absent and ``strict``
            is False.
        strict:
            When True, a missing variable raises ``KeyError`` instead of
            returning ``default``.

    Returns:
        The variable instance object if present; otherwise ``default``.

    Raises:
        KeyError: If ``strict`` is True and the variable does not exist.
    """
    struct = self._get_variable_struct(selector)
    if struct is not None:
        return struct.instance
    if strict:
        raise KeyError(f"{selector} not exist")
    return default
def get_value(
self,
selector: str,

View File

@@ -132,24 +132,24 @@ class WorkflowExecutor:
start_time = datetime.datetime.now()
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
# Execute the workflow
try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
@@ -175,7 +175,7 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
@@ -231,23 +231,23 @@ class WorkflowExecutor:
}
}
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
full_content = ''
self.stream_coordinator.update_scope_activation("sys")
@@ -322,7 +322,7 @@ class WorkflowExecutor:
)
logger.info(
f"Workflow execution completed (streaming), "
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
)
yield {

View File

@@ -196,7 +196,7 @@ class BaseNode(ABC):
timeout=timeout
)
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
# Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result)
@@ -219,7 +219,7 @@ class BaseNode(ABC):
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)."
)
@@ -230,7 +230,7 @@ class BaseNode(ABC):
variable_pool,
)
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution failed: {e}",
exc_info=True,
@@ -307,10 +307,10 @@ class BaseNode(ABC):
"done": done
})
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
@@ -337,7 +337,7 @@ class BaseNode(ABC):
yield state_update | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)",
@@ -347,7 +347,7 @@ class BaseNode(ABC):
)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output

View File

@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
if config.flatten:
outputs['output'] = config.output_type
else:
outputs['output'] = VariableType.ARRAY_STRING
outputs['output'] = VariableType.NESTED_ARRAY
else:
outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs

View File

@@ -24,6 +24,8 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown"
# Node types whose execution fans out into multiple conditional branches
# (if/else arms, HTTP success vs. error handling, classifier classes).
# NOTE(review): the consumers of this list are not visible in this chunk —
# presumably edge/branch wiring logic; confirm before changing its type.
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import uuid
from typing import Any, Callable, Coroutine
import httpx
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
logger = logging.getLogger(__file__)
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
))
case HttpContentType.FROM_DATA:
data = {}
content["files"] = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file":
# TODO: File support (Feature)
pass
content["files"][self._render_template(item.key, variable_pool)] = (
uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data
case HttpContentType.BINARY:
# TODO: File support (Feature)
pass
content["files"] = []
file_instence = variable_pool.get_instance(self.typed_config.body.data)
if isinstance(file_instence, ArrayVariable):
for v in file_instence.value:
if isinstance(v, FileVariable):
content["files"].append(
(
"files", (uuid.uuid4().hex, await v.get_content())
)
)
elif isinstance(file_instence, FileVariable):
content["files"].append(
(
"file", (uuid.uuid4().hex, await file_instence.get_content())
)
)
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), variable_pool
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool)
**(await self._build_content(variable_pool))
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -123,10 +123,10 @@ class NodeFactory:
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
raise ValueError(f"Unsupported node type: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod

View File

@@ -12,9 +12,20 @@ class ExpressionEvaluator:
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
@classmethod
def normalize_template(cls, template: str) -> str:
    # Rewrite legacy numeric node references of the form "{{ 123.attr }}"
    # into the current accessor syntax '{{ node["123"].attr }}'.
    # Expressions that do not start with a bare numeric node id are left
    # untouched.
    pattern = re.compile(
        r"\{\{\s*(\d+)\.(\w+)\s*}}"
    )
    return pattern.sub(
        r'{{ node["\1"].\2 }}',
        template
    )
@classmethod
def evaluate(
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
@@ -37,6 +48,7 @@ class ExpressionEvaluator:
"""
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()

View File

@@ -5,6 +5,7 @@
"""
import logging
import re
from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -39,6 +40,16 @@ class TemplateRenderer:
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}',
template
)
def render(
self,
template: str,
@@ -95,7 +106,7 @@ class TemplateRenderer:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)

View File

@@ -1,8 +1,10 @@
from typing import Any, TypeVar, Type, Generic
import httpx
from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
from app.core.config import settings
T = TypeVar("T", bound=BaseVariable)
@@ -80,8 +82,23 @@ class FileVariable(BaseVariable):
def get_value(self) -> Any:
return self.value.model_dump()
async def get_content(self):
total_bytes = 0
chunks = []
class ArrayObject(BaseVariable, Generic[T]):
async with httpx.AsyncClient() as client:
async with client.stream("GET", self.value.url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(8192):
total_bytes += len(chunk)
if total_bytes > settings.MAX_FILE_SIZE:
raise ValueError(f"File too large: {total_bytes} bytes")
chunks.append(chunk)
return b"".join(chunks)
class ArrayVariable(BaseVariable, Generic[T]):
type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]):
@@ -108,7 +125,7 @@ class ArrayObject(BaseVariable, Generic[T]):
return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable):
class NestedArrayVariable(BaseVariable):
type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]:
@@ -116,23 +133,23 @@ class NestedArrayObject(BaseVariable):
raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = []
for v in value:
if not isinstance(v, ArrayObject):
if not isinstance(v, list):
raise TypeError("All elements must be of type list")
final_value.append(v)
final_value.append(make_array(AnyVariable, v))
return final_value
def to_literal(self) -> str:
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
def get_value(self) -> Any:
return [[item.get_value() for item in row] for row in self.value]
return [[item for item in row.get_value()] for row in self.value]
@deprecated(
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
category=RuntimeWarning
)
class AnyObject(BaseVariable):
class AnyVariable(BaseVariable):
type = 'any'
def valid_value(self, value: Any) -> Any:
@@ -142,10 +159,10 @@ class AnyObject(BaseVariable):
return str(self.value)
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
"""简化 ArrayObject 创建,不需要重复写类型"""
def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
"""简化 ArrayVariable 创建,不需要重复写类型"""
return ArrayObject(child_type, value)
return ArrayVariable(child_type, value)
def create_variable_instance(var_type: VariableType, value: Any) -> T:
@@ -168,7 +185,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return make_array(DictVariable, value)
case VariableType.ARRAY_FILE:
return make_array(FileVariable, value)
case VariableType.NESTED_ARRAY:
return NestedArrayVariable(value)
case VariableType.ANY:
return AnyObject(value)
return AnyVariable(value)
case _:
raise TypeError(f"Invalid type - {var_type}")

View File

@@ -9,7 +9,7 @@ Classes:
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
from sqlalchemy import Column, String, DateTime, Text, ForeignKey, Boolean
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
@@ -25,6 +25,9 @@ class OntologyClass(Base):
# 类型信息
class_name = Column(String(200), nullable=False, comment="类型名称")
class_description = Column(Text, nullable=True, comment="类型描述")
# 系统默认标识
is_system_default = Column(Boolean, default=False, nullable=False, comment="是否为系统默认类型")
# 外键:关联到本体场景
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")

View File

@@ -9,7 +9,7 @@ Classes:
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint, Boolean
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
@@ -28,6 +28,9 @@ class OntologyScene(Base):
# 场景信息
scene_name = Column(String(200), nullable=False, comment="场景名称")
scene_description = Column(Text, nullable=True, comment="场景描述")
# 系统默认标识
is_system_default = Column(Boolean, default=False, nullable=False, index=True, comment="是否为系统默认场景")
# 外键:关联到工作空间
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")

View File

@@ -428,19 +428,17 @@ class ModelConfigRepository:
try:
# 查询ModelConfig关联的ModelApiKey筛选出匹配的model_config_id
model_config_ids = db.query(ModelConfig.id).join(
ModelBase, ModelConfig.model_id == ModelBase.id
).filter(
model_config_ids = db.query(ModelConfig.id).filter(
and_(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public
),
ModelBase.provider == provider,
ModelConfig.provider == provider,
ModelConfig.is_active,
~ModelConfig.is_composite
)
).distinct().all()
).all()
db_logger.debug(f"查询成功: 数量={len(model_config_ids)}")
return [row[0] for row in model_config_ids]

View File

@@ -68,7 +68,7 @@ class WorkflowImportSave(BaseModel):
"""工作流导入请求"""
temp_id: str
name: str
description: str
description: str | None = Field(default=None)
# ==================== 工作流配置 ====================

View File

@@ -816,11 +816,10 @@ class MemoryAgentService:
"""
统计知识库类型分布,包含:
1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
3. total: 所有类型的总和
2. total: 所有类型的总和
参数:
- end_user_id: 用户组ID可选未提供时 memory 统计为 0
- end_user_id: 用户组ID可选保留参数以保持接口兼容性
- only_active: 是否仅统计有效记录
- current_workspace_id: 当前工作空间ID可选未提供时知识库统计为 0
- db: 数据库会话
@@ -831,7 +830,6 @@ class MemoryAgentService:
"Web": count,
"Third-party": count,
"Folder": count,
"memory": chunk_count,
"total": sum_of_all
}
"""
@@ -878,51 +876,8 @@ class MemoryAgentService:
logger.error(f"知识库类型统计失败: {e}")
raise Exception(f"知识库类型统计失败: {e}")
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
try:
if current_workspace_id:
# 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
apps = [AppSchema.model_validate(h) for h in apps_orm]
app_ids = [app.id for app in apps]
# 获取所有宿主
end_users = []
for app_id in app_ids:
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend(h for h in end_user_orm_list)
# 统计所有宿主的 Chunk 总数
total_chunks = 0
for end_user in end_users:
end_user_id_str = str(end_user.id)
memory_query = """
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
"""
neo4j_result = await _neo4j_connector.execute_query(
memory_query,
end_user_id=end_user_id_str,
)
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
total_chunks += chunk_count
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
result["memory"] = total_chunks
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
else:
# 没有 workspace_id 时,返回 0
result["memory"] = 0
logger.info("未提供 workspace_idmemory 统计为 0")
except Exception as e:
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
# 如果 Neo4j 查询失败memory 设为 0
result["memory"] = 0
# 2. 统计 Neo4j 中的 memory 总量已移除
# memory 字段不再返回
# 3. 计算知识库类型总和(不包括 memory
result["total"] = (

View File

@@ -101,34 +101,141 @@ async def run_pilot_extraction(
)
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
# ========== 步骤 2.1: 语义剪枝 ==========
pruned_dialogs = [dialog]
deleted_messages = [] # 记录被删除的消息
pruning_stats = None # 保存剪枝统计信息,用于最终汇总
if memory_config.pruning_enabled:
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
pruning_config_dict = {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
"llm_model_id": str(memory_config.llm_model_id),
}
config = PruningConfig(**pruning_config_dict)
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
# 记录剪枝前的消息(用于对比)
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
original_msg_count = len(original_messages)
# 执行剪枝
pruner = SemanticPruner(config=config, llm_client=llm_client)
pruned_dialogs = await pruner.prune_dataset([dialog])
# 计算剪枝结果并找出被删除的消息
if pruned_dialogs and pruned_dialogs[0].context:
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
remaining_msg_count = len(remaining_messages)
deleted_msg_count = original_msg_count - remaining_msg_count
# 找出被删除的消息(基于索引精确匹配)
# 为剩余消息创建带索引的列表,用于精确追踪
remaining_with_index = []
remaining_idx = 0
for orig_idx, orig_msg in enumerate(original_messages):
if remaining_idx < len(remaining_messages) and \
orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \
orig_msg["content"] == remaining_messages[remaining_idx]["content"]:
remaining_with_index.append(orig_idx)
remaining_idx += 1
# 找出未在保留列表中的消息索引
deleted_messages = [
{"index": idx, "role": msg["role"], "content": msg["content"]}
for idx, msg in enumerate(original_messages)
if idx not in remaining_with_index
]
# 保存剪枝统计信息用于最终汇总只保留deleted_count
pruning_stats = {
"enabled": True,
"scene": config.pruning_scene,
"threshold": config.pruning_threshold,
"deleted_count": deleted_msg_count,
}
# 输出剪枝结果(显示删除的消息详情)
pruning_result = {
"type": "pruning",
"deleted_messages": deleted_messages,
}
logger.info(
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
)
if progress_callback:
await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
else:
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
pruned_dialogs = [dialog]
except Exception as e:
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
pruned_dialogs = [dialog]
if progress_callback:
error_result = {
"type": "pruning",
"error": str(e),
"fallback": "使用原始对话"
}
await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
else:
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
pruning_stats = {
"enabled": False,
}
# ========== 步骤 2.2: 语义分块 ==========
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
data=pruned_dialogs,
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed dialogue text: {len(messages)} messages")
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
if hasattr(dlg, 'chunks') and dlg.chunks:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"type": "chunking",
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 构建预处理完成总结(包含剪枝统计)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
# 添加剪枝统计信息
if pruning_stats:
preprocessing_summary["pruning"] = pruning_stats
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)

View File

@@ -1163,11 +1163,32 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
"""
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.core.language_utils import validate_language
from app.repositories.end_user_repository import EndUserRepository
from app.db import get_db
import re
# 验证语言参数
language = validate_language(language)
# 获取用户的 other_name 字段
user_display_name = "该用户" if language == "zh" else "the user"
if end_user_id:
try:
# 获取数据库会话并查询用户信息
db = next(get_db())
try:
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name:
user_display_name = end_user.other_name
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
# 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
@@ -1184,7 +1205,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
user_id=user_summary_tool.user_id,
entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)",
statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)",
language=language
language=language,
user_display_name=user_display_name
)
messages = [

View File

@@ -580,6 +580,7 @@ class WorkflowService:
# "variables": result.get("variables"),
# "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"message": result.get("output"), # 最终输出(字符串)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),

View File

@@ -30,6 +30,7 @@ from app.schemas.workspace_schema import (
WorkspaceModelsUpdate,
WorkspaceUpdate,
)
from app.config.default_ontology_initializer import DefaultOntologyInitializer
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
@@ -129,7 +130,7 @@ def _create_workspace_only(
raise
def create_workspace(
db: Session, workspace: WorkspaceCreate, user: User
db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh"
) -> Workspace:
business_logger.info(
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
@@ -145,10 +146,68 @@ def create_workspace(
db=db, workspace=workspace, tenant_id=user.tenant_id
)
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
db.commit()
db.flush() # 使用 flush 而不是 commit获取 ID 但不提交事务
db.refresh(db_workspace)
# Initialize default ontology scenes for the workspace (先创建本体场景)
default_scene_id = None
try:
initializer = DefaultOntologyInitializer(db)
success, error_msg = initializer.initialize_default_scenes(
db_workspace.id, language=language
)
if success:
business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})"
)
# 获取默认场景ID优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 优先尝试获取教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id)
if education_scene:
default_scene_id = education_scene.scene_id
business_logger.info(
f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})"
)
else:
# 如果教育场景不存在,尝试获取情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id)
if companion_scene:
default_scene_id = companion_scene.scene_id
business_logger.info(
f"教育场景不存在使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})"
)
else:
business_logger.warning(
f"未找到任何默认场景 (education={education_scene_name}, companion={companion_scene_name})"
)
else:
business_logger.warning(
f"为工作空间 {db_workspace.id} 创建默认本体场景失败: {error_msg} (language={language})"
)
except Exception as ontology_error:
business_logger.error(
f"为工作空间 {db_workspace.id} 创建默认本体场景异常: {str(ontology_error)} (language={language})"
)
# Don't fail workspace creation if default ontology initialization fails
# The workspace can still function without default ontology scenes
# Create default memory config for the workspace (only for neo4j storage types)
# 将默认场景ID教育场景或情感陪伴场景关联到记忆配置
if workspace.storage_type == 'neo4j':
try:
_create_default_memory_config(
@@ -158,9 +217,10 @@ def create_workspace(
llm_id=llm,
embedding_id=embedding,
rerank_id=rerank,
scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景
)
business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功"
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})"
)
except Exception as mc_error:
business_logger.error(
@@ -209,7 +269,6 @@ def create_workspace(
db=db,
knowledge=knowledge_data
)
db.commit()
business_logger.info(
f"为工作空间 {db_workspace.id} 自动创建知识库成功: "
f"{db_knowledge.name} (ID: {db_knowledge.id})"
@@ -224,6 +283,12 @@ def create_workspace(
BizCode.INTERNAL_ERROR
)
# 统一提交所有更改
db.commit()
business_logger.info(
f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交"
)
return db_workspace
except Exception as e:
@@ -919,6 +984,43 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
f"Workspace {workspace.id} missing default memory config, creating one"
)
# 尝试获取默认场景ID优先教育场景其次情感陪伴场景
default_scene_id = None
try:
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.config.default_ontology_config import (
ONLINE_EDUCATION_SCENE,
EMOTIONAL_COMPANION_SCENE,
get_scene_name
)
scene_repo = OntologySceneRepository(db)
# 尝试中文和英文场景名称
for language in ["zh", "en"]:
# 优先尝试教育场景
education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language)
education_scene = scene_repo.get_by_name(education_scene_name, workspace.id)
if education_scene:
default_scene_id = education_scene.scene_id
business_logger.info(
f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}"
)
break
# 如果教育场景不存在,尝试情感陪伴场景
companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language)
companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id)
if companion_scene:
default_scene_id = companion_scene.scene_id
business_logger.info(
f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}"
)
break
except Exception as scene_error:
business_logger.warning(
f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}"
)
try:
_create_default_memory_config(
db=db,
@@ -927,6 +1029,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None:
llm_id=uuid.UUID(workspace.llm) if workspace.llm else None,
embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None,
rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None,
scene_id=default_scene_id, # 传入默认场景ID优先教育场景其次情感陪伴场景
)
except Exception as e:
business_logger.error(
@@ -1008,6 +1111,7 @@ def _create_default_memory_config(
llm_id: Optional[uuid.UUID] = None,
embedding_id: Optional[uuid.UUID] = None,
rerank_id: Optional[uuid.UUID] = None,
scene_id: Optional[uuid.UUID] = None,
) -> None:
"""Create a default memory config for a newly created workspace.
@@ -1018,6 +1122,7 @@ def _create_default_memory_config(
llm_id: Optional LLM model ID
embedding_id: Optional embedding model ID
rerank_id: Optional rerank model ID
scene_id: Optional ontology scene ID (默认关联教育场景)
"""
from app.models.memory_config_model import MemoryConfig
@@ -1031,12 +1136,13 @@ def _create_default_memory_config(
llm_id=str(llm_id) if llm_id else None,
embedding_id=str(embedding_id) if embedding_id else None,
rerank_id=str(rerank_id) if rerank_id else None,
scene_id=scene_id, # 关联本体场景ID
state=True, # Active by default
is_default=True, # Mark as workspace default
)
db.add(default_config)
db.commit()
db.flush() # 使用 flush 而不是 commit让调用者统一提交
business_logger.info(
"Created default memory config for workspace",
@@ -1044,5 +1150,6 @@ def _create_default_memory_config(
"workspace_id": str(workspace_id),
"config_id": str(config_id),
"config_name": default_config.config_name,
"scene_id": str(scene_id) if scene_id else None,
}
)

View File

@@ -1304,6 +1304,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
"workspace_id": workspace_id,
"elapsed_time": elapsed_time,
}
@celery_app.task(
    name="app.tasks.write_all_workspaces_memory_task",
    bind=True,
    ignore_result=False,
    max_retries=3,
    acks_late=True,          # re-deliver if the worker dies mid-run
    time_limit=3600,         # hard kill after 1 hour
    soft_time_limit=3300,    # soft timeout 5 minutes before the hard limit
)
def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
    """Periodic task: iterate over all workspaces and persist memory increments.

    This task will:
        1. Query all active workspaces.
        2. Aggregate the memory total for each workspace (summed over the
           deduplicated end users of all of its active apps).
        3. Write one aggregate row per workspace into the
           ``memory_increments`` table.

    Returns:
        A dict describing the task outcome. On success it contains
        ``status``, ``message``, ``workspace_count``, ``success_count``,
        ``total_memory`` and per-workspace ``workspace_results``; on failure
        ``status`` is ``"FAILURE"`` with an ``error`` string. ``elapsed_time``
        and ``task_id`` are appended by the synchronous wrapper below.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        # Deferred imports — presumably to avoid circular imports / heavy
        # module-load cost in the Celery worker. TODO confirm.
        from app.core.logging_config import get_api_logger
        from app.models.workspace_model import Workspace
        from app.models.app_model import App
        from app.models.end_user_model import EndUser
        from app.repositories.memory_increment_repository import write_memory_increment
        from app.services.memory_storage_service import search_all

        api_logger = get_api_logger()
        with get_db_context() as db:
            try:
                # Fetch every active workspace; inactive ones are skipped entirely.
                workspaces = db.query(Workspace).filter(
                    Workspace.is_active.is_(True)
                ).all()
                if not workspaces:
                    api_logger.warning("没有找到活跃的工作空间")
                    return {
                        "status": "SUCCESS",
                        "message": "没有找到活跃的工作空间",
                        "workspace_count": 0,
                        "workspace_results": []
                    }
                api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
                all_workspace_results = []
                # Process each workspace independently so one failure does not
                # abort the whole batch.
                for workspace in workspaces:
                    workspace_id = workspace.id
                    api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
                    try:
                        # 1. All active (non-deleted) apps in this workspace.
                        apps = db.query(App).filter(
                            App.workspace_id == workspace_id,
                            App.is_active.is_(True)
                        ).all()
                        if not apps:
                            # No apps means no end users — record a zero total
                            # so the time series stays continuous.
                            memory_increment = write_memory_increment(
                                db=db,
                                workspace_id=workspace_id,
                                total_num=0
                            )
                            all_workspace_results.append({
                                "workspace_id": str(workspace_id),
                                "workspace_name": workspace.name,
                                "status": "SUCCESS",
                                "total_num": 0,
                                "end_user_count": 0,
                                "memory_increment_id": str(memory_increment.id),
                                "created_at": memory_increment.created_at.isoformat(),
                            })
                            api_logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0")
                            continue
                        # 2. Distinct end-user ids across all of the workspace's apps.
                        app_ids = [app.id for app in apps]
                        end_users = db.query(EndUser.id).filter(
                            EndUser.app_id.in_(app_ids)
                        ).distinct().all()
                        # 3. Sum each end user's memory total via search_all.
                        total_num = 0
                        end_user_details = []
                        # NOTE(review): end_user_details is accumulated but never
                        # used or returned — consider adding it to the result
                        # payload or removing it.
                        for (end_user_id,) in end_users:
                            try:
                                # Query this end user's memory total.
                                result = await search_all(str(end_user_id))
                                user_total = result.get("total", 0)
                                total_num += user_total
                                end_user_details.append({
                                    "end_user_id": str(end_user_id),
                                    "total": user_total
                                })
                            except Exception as e:
                                # A single user's failure is logged and counted
                                # as 0; remaining users are still processed.
                                api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
                                end_user_details.append({
                                    "end_user_id": str(end_user_id),
                                    "total": 0,
                                    "error": str(e)
                                })
                        # 4. Persist the aggregate for this workspace.
                        memory_increment = write_memory_increment(
                            db=db,
                            workspace_id=workspace_id,
                            total_num=total_num
                        )
                        all_workspace_results.append({
                            "workspace_id": str(workspace_id),
                            "workspace_name": workspace.name,
                            "status": "SUCCESS",
                            "total_num": total_num,
                            "end_user_count": len(end_users),
                            "memory_increment_id": str(memory_increment.id),
                            "created_at": memory_increment.created_at.isoformat(),
                        })
                        api_logger.info(
                            f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
                        )
                    except Exception as e:
                        # Roll back the failed transaction so the shared session
                        # is usable for the next workspace.
                        db.rollback()
                        api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
                        all_workspace_results.append({
                            "workspace_id": str(workspace_id),
                            "workspace_name": workspace.name,
                            "status": "FAILURE",
                            "error": str(e),
                            "total_num": 0,
                            "end_user_count": 0,
                        })
                # Batch-level summary across all workspaces.
                total_memory = sum(r.get("total_num", 0) for r in all_workspace_results)
                success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS")
                return {
                    "status": "SUCCESS",
                    "message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}",
                    "workspace_count": len(workspaces),
                    "success_count": success_count,
                    "total_memory": total_memory,
                    "workspace_results": all_workspace_results
                }
            except Exception as e:
                api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
                return {
                    "status": "FAILURE",
                    "error": str(e),
                    "workspace_count": 0,
                    "workspace_results": []
                }

    try:
        # nest_asyncio allows re-entering a running event loop (best effort —
        # silently skipped when the package is not installed).
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass
        # Reuse an existing event loop when possible; otherwise create one.
        # NOTE(review): asyncio.get_event_loop() outside a running loop is
        # deprecated in newer Python versions — confirm the worker's Python
        # version tolerates this pattern.
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        result = loop.run_until_complete(_run())
        elapsed_time = time.time() - start_time
        result["elapsed_time"] = elapsed_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        elapsed_time = time.time() - start_time
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": elapsed_time,
            "task_id": self.request.id
        }
@celery_app.task(

View File

@@ -1,4 +1,34 @@
{
"v0.2.5": {
"introduction": {
"codeName": "行云",
"releaseDate": "2026-2-26",
"upgradePosition": "🐻 精炼根基,优化核心用户体验与系统稳定性",
"coreUpgrades": [
"1. 用户体验与国际化 🎨<br>* 语言参数修复:语言偏好现正确保留<br>* 邮箱修改支持:用户可直接在用户管理系统中修改邮箱地址",
"2. 工作流可视化增强 💬<br>* 循环与迭代节点输出展示:实时显示执行进度和中间输出,便于调试复杂迭代过程<br>* 变量支持回车选择:支持回车键确认变量选择,简化工作流配置流程",
"3. 优化模型管理 ⚙️<br>* 模型广场移除自定义模型,优化模型使用体验",
"4. 稳健性与缺陷修复 🔧<br>* 知识图谱构建修复:解决知识图谱构建流程稳定性问题,确保更可靠的实体提取和关系映射",
"<br>",
"版本 0.2.5 通过解决国际化边界情况和改进工作流透明度,构建更具生产就绪性的平台。工作流可视化改进为更复杂的调试和监控能力奠定基础。未来将继续深化企业就绪性,扩展用户管理功能、优化知识图谱智能和增强工作流编排能力,在可观测性、性能优化和无缝集成模式方面持续改进。",
"智慧致远 🐻✨"
]
},
"introduction_en": {
"codeName": "Flowing Clouds",
"releaseDate": "2026-2-26",
"upgradePosition": "🐻 Refined foundations with enhanced user experience and system stability",
"coreUpgrades": [
"1. User Experience & Internationalization 🎨<br>* Language parameter fix: language preferences are now correctly retained<br>* Email Update Support: Users can now modify email addresses directly in user management system",
"2. Workflow Visualization Enhancements 💬<br>* Loop & Iteration Node Output Display: Real-time display of execution progress and intermediate outputs for easier debugging<br>* Variable Selection with Enter Key: Enabled Enter key confirmation for streamlined variable assignment",
"3. Optimized Model Management ⚙️<br>* Custom models have been removed from the Model marketplace to optimize the model usage experience",
"4. Robustness & Bug Fixes 🔧<br>* Knowledge Graph Construction Fix: Addressed stability issues in knowledge graph pipeline for more reliable entity extraction and relationship mapping",
"<br>",
"Version 0.2.5 matures MemoryBear's operational foundations by addressing internationalization edge cases and improving workflow transparency. The workflow visualization improvements lay groundwork for sophisticated debugging and monitoring capabilities. Looking forward, we will deepen enterprise readiness by expanding user management features, refining knowledge graph intelligence, and enhancing workflow orchestration with continued improvements in observability, performance optimization, and seamless integration patterns.",
"Intelligent Resilience 🐻✨"
]
}
},
"v0.2.4": {
"introduction": {
"codeName": "智远",

View File

@@ -0,0 +1,44 @@
"""202602281918
Revision ID: 4bf27c66ae63
Revises: 7672d8f0f939
Create Date: 2026-02-28 19:18:38.332468
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '4bf27c66ae63'
down_revision: Union[str, None] = '7672d8f0f939'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add an ``is_system_default`` boolean flag to ontology tables.

    Uses a three-phase pattern so the migration is safe on tables that
    already contain rows: add the column as nullable, backfill existing
    rows with ``false``, then tighten to NOT NULL.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Add columns as nullable first
    op.add_column('ontology_class', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认类型'))
    op.add_column('ontology_scene', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认场景'))
    # Set default value for existing rows
    op.execute("UPDATE ontology_class SET is_system_default = false WHERE is_system_default IS NULL")
    op.execute("UPDATE ontology_scene SET is_system_default = false WHERE is_system_default IS NULL")
    # Now make columns NOT NULL
    op.alter_column('ontology_class', 'is_system_default', nullable=False)
    op.alter_column('ontology_scene', 'is_system_default', nullable=False)
    # Index only the scene flag — presumably queried for default-scene lookup.
    op.create_index(op.f('ix_ontology_scene_is_system_default'), 'ontology_scene', ['is_system_default'], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse :func:`upgrade`: drop the index first, then both columns."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_ontology_scene_is_system_default'), table_name='ontology_scene')
    op.drop_column('ontology_scene', 'is_system_default')
    op.drop_column('ontology_class', 'is_system_default')
    # ### end Alembic commands ###

2
web/.gitignore vendored
View File

@@ -23,4 +23,4 @@ dist-ssr
*.sln
*.sw?
vite.config.js
package-lock.json
package-lock.json

View File

@@ -1,205 +1,195 @@
# i18n 中英文对比报告
# Memory Bear 前端项目 - 中英文国际化对比报告
## 📊 统计概览
生成时间: 2024
- **中文键总数**: 1136
- **英文键总数**: 1052
- **中文缺失**: 27 个键
- **英文缺失**: 111 个键
## 📊 概览统计
### 文件信息
- **中文文件**: `src/i18n/zh.ts`
- **英文文件**: `src/i18n/en.ts`
### 模块统计
| 模块名称 | 中文键数 | 英文键数 | 状态 |
|---------|---------|---------|------|
| translation | ✅ | ✅ | 完整 |
## 🔍 详细对比分析
### 1. 主要模块对比
#### 1.1 基础信息 (title, memoryBear)
-**完全匹配**
- 中文: "记忆熊.AI"
- 英文: "Memory Bear.AI"
#### 1.2 首页模块 (index)
-**完全匹配** - 包含所有子键
#### 1.3 版本信息 (version)
-**完全匹配**
#### 1.4 快速操作 (quickActions)
-**完全匹配** - 包含所有功能入口
#### 1.5 引导模块 (guide)
-**完全匹配**
#### 1.6 首页引导 (indexTour)
-**完全匹配**
#### 1.7 菜单模块 (menu)
-**完全匹配** - 包含所有导航项
#### 1.8 仪表盘 (dashboard)
-**完全匹配** - 包含所有统计指标
#### 1.9 表格 (table)
-**完全匹配**
#### 1.10 头部 (header)
-**完全匹配**
#### 1.11 语言 (language)
-**完全匹配**
#### 1.12 用户管理 (user)
-**完全匹配** - 包含所有用户相关功能
#### 1.13 时区 (timezones)
-**完全匹配** - 包含全球主要时区
#### 1.14 通用 (common)
-**完全匹配** - 包含所有通用操作和提示
#### 1.15 模型管理 (model)
-**完全匹配**
#### 1.16 新模型管理 (modelNew)
-**完全匹配**
#### 1.17 知识库 (knowledgeBase)
-**完全匹配** - 包含所有知识库功能
- 包含知识图谱相关配置
#### 1.18 API (api)
-**完全匹配**
#### 1.19 记忆管理 (memory)
-**完全匹配**
#### 1.20 成员管理 (member)
-**完全匹配**
#### 1.21 记忆摘要 (memorySummary)
-**完全匹配**
#### 1.22 遗忘引擎 (forgettingEngine)
-**完全匹配**
#### 1.23 应用管理 (application)
-**完全匹配** - 包含所有应用配置功能
- 包含工作流、Agent配置等
#### 1.24 用户记忆 (userMemory)
-**完全匹配** - 包含所有记忆类型
#### 1.25 空间管理 (space)
-**完全匹配**
#### 1.26 记忆萃取引擎 (memoryExtractionEngine)
-**完全匹配** - 包含所有配置参数
#### 1.27 记忆对话 (memoryConversation)
-**完全匹配**
#### 1.28 登录 (login)
-**完全匹配**
#### 1.29 空状态 (empty)
-**完全匹配**
#### 1.30 API密钥 (apiKey)
-**完全匹配**
#### 1.31 工具管理 (tool)
-**完全匹配** - 包含MCP服务、内置工具、自定义工具
#### 1.32 工作流 (workflow)
-**完全匹配** - 包含所有节点配置
#### 1.33 情感引擎 (emotionEngine)
-**完全匹配**
#### 1.34 情感详情 (statementDetail)
-**完全匹配**
#### 1.35 反思引擎 (reflectionEngine)
-**完全匹配**
#### 1.36 定价 (pricing)
-**完全匹配** - 包含所有套餐信息
#### 1.37 遗忘详情 (forgetDetail)
-**完全匹配**
#### 1.38 情景记忆详情 (episodicDetail)
-**完全匹配**
#### 1.39 内隐记忆详情 (implicitDetail)
-**完全匹配**
#### 1.40 短期记忆详情 (shortTermDetail)
-**完全匹配**
#### 1.41 感知记忆详情 (perceptualDetail)
-**完全匹配**
#### 1.42 外显记忆详情 (explicitDetail)
-**完全匹配**
#### 1.43 工作记忆详情 (workingDetail)
-**完全匹配**
#### 1.44 本体工程 (ontology)
-**完全匹配**
#### 1.45 提示词工程 (prompt)
-**完全匹配**
#### 1.46 技能库 (skills)
-**完全匹配**
## ✅ 结论
### 整体评估
- **状态**: 🟡 部分同步(存在缺失键)
- **中英文键值对**: 英文缺失 111 个键,中文缺失 27 个键(详见下文清单)
- **结构一致性**: 主要模块结构一致,但部分模块存在单侧缺失
### 优点
1. ✅ 所有模块的中英文翻译完整
2. ✅ 键名结构完全一致
3. ✅ 嵌套层级对应准确
4. ✅ 特殊字符和变量占位符使用正确
5. ✅ 时区、语言等枚举值完整
### 建议
1. 定期检查新增功能的国际化覆盖
2. 建议添加自动化测试确保中英文键值对同步
3. 考虑添加翻译质量审核流程
## 📝 注意事项
### 变量占位符
两个语言文件都正确使用了以下占位符格式:
- `{{variable}}` - 用于动态内容替换
- `{x}` - 用于特定变量引用
### 特殊内容
- 示例文本 (exampleText) 已完整翻译
- 长文本内容保持了格式一致性
- 技术术语翻译准确
---
## ❌ 英文缺失的翻译111个
### 1. Application 模块 (3个)
- `application.cluster` - 集群
- `application.clusterDesc` - 创建Agent集群
- `application.fullAmount` - 全量
### 2. Role 角色管理模块 (15个)
- `role.roleManagement` - 角色管理
- `role.roleId` - 角色ID
- `role.roleName` - 角色名称
- `role.roleCode` - 角色编码
- `role.description` - 角色描述
- `role.status` - 状态
- `role.enabled` - 已启用
- `role.disabled` - 已停用
- `role.createTime` - 创建时间
- `role.createRole` - 新建角色
- `role.editRole` - 编辑角色
- `role.roleTemplate` - 角色模板
- `role.emptyTemplate` - 空模板
- `role.adminTemplate` - 管理员模板
- `role.userTemplate` - 用户模板
- `role.confirmDelete` - 确定要删除这个角色吗?
- `role.createSuccess` - 角色创建成功
- `role.updateSuccess` - 角色更新成功
- `role.deleteSuccess` - 角色删除成功
- `role.createFailed` - 角色创建失败
- `role.updateFailed` - 角色更新失败
- `role.deleteFailed` - 角色删除失败
### 3. Tenant 租户管理模块 (20个)
- `tenant.tenantId` - 租户ID
- `tenant.tenantName` - 租户名称
- `tenant.contactPerson` - 联系人
- `tenant.contactInfo` - 联系方式
- `tenant.status` - 状态
- `tenant.enabled` - 启用
- `tenant.disabled` - 禁用
- `tenant.expiryDate` - 到期时间
- `tenant.createTenant` - 新增租户
- `tenant.editTenant` - 编辑租户
- `tenant.searchPlaceholder` - 搜索租户ID、名称、联系人或联系方式
- `tenant.confirmDelete` - 确定要删除该租户吗?
- `tenant.confirmBatchDelete` - 确定要批量删除选中的租户吗?
- `tenant.fetchFailed` - 获取租户数据失败
- `tenant.batchEnableSuccess` - 批量启用成功
- `tenant.batchEnableFailed` - 批量启用失败
- `tenant.batchDisableSuccess` - 批量停用成功
- `tenant.batchDisableFailed` - 批量停用失败
- `tenant.exportSuccess` - 导出成功
- `tenant.batchDeleteSuccess` - 批量删除成功
- `tenant.batchDeleteFailed` - 批量删除失败
- `tenant.saveFailed` - 保存失败
- `tenant.batchImport` - 批量导入
### 4. User 用户管理模块 (13个)
- `user.tenantName` - 所属租户
- `user.password` - 密码
- `user.expiryDate` - 有效期
- `user.expiryDateDue` - 有效期至
- `user.batchImport` - 批量导入
- `user.batchImportUser` - 批量导入用户
- `user.downloadTemplate` - 下载导入模板
- `user.templateDownloadSuccess` - 模板下载成功
- `user.startImport` - 开始导入
- `user.batchImportSuccess` - 批量导入成功
- `user.importFailed` - 导入失败,请检查文件格式
- `user.noFileSelected` - 请选择要导入的文件
- `user.onlyXlsxOrCsv` - 只能上传 .xlsx 或 .csv 格式的文件
- `user.reselect` - 重新选择
- `user.noFileSelectedTip` - 未选择任何文件
- `user.downloadTemplateTip` - 请下载模板,填写用户信息后上传。
### 5. Product 产品管理模块 (13个)
- `product.applicationManagement` - 应用管理
- `product.createApplication` - 创建应用
- `product.applicationName` - 应用名称
- `product.applicationIcon` - 应用图标
- `product.applicationNameRequired` - 请输入应用名称
- `product.associationStatus` - 关联状态
- `product.associated` - 已关联
- `product.notAssociated` - 未关联
- `product.unassociate` - 解除关联
- `product.unassociateSuccess` - 解除关联成功
- `product.unassociateFailed` - 解除关联失败
- `product.viewKey` - 查看KEY
- `product.viewStats` - 查看统计
- `product.disableSuccess` - 停用成功
- `product.enableSuccess` - 启用成功
- `product.operationFailed` - 操作失败
### 6. 其他模块 (47个)
- `count` - 计数: {{count}}
- `increment` - 增加
- `decrement` - 减少
- `reset` - 重置
- `switchLanguage` - 切换语言
- `home.title` - 首页
- `home.welcome` - 欢迎使用我们的带单页路由的 React 应用!
- `home.counterCard` - 计数器演示
- `home.aboutCard` - 关于我们
- `home.workflowCard` - 工作流编辑器
- `home.websocketDemoCard` - WebSocket 演示
- `home.sseDemoCard` - SSE演示
- `workflow.title` - 工作流编辑器
- `workflow.description` - 拖拽节点创建连接,构建您的工作流程。点击节点可进行配置。
- `workflow.addNode` - 添加节点
- `workflow.deleteNode` - 删除选中
- `workflow.saveWorkflow` - 保存工作流
- `workflow.startNode` - 触发节点
- `workflow.conditionNode` - 条件判断
- `workflow.actionNode` - 执行动作
- `workflow.endNode` - 结束节点
- `workflow.newNode` - 新节点
- `workflow.node` - 节点
- `workflow.nodesCreated` - 已创建节点
- `workflow.loadingNodes` - 正在加载节点 {{progress}}%
- `workflow.loadingFailed` - 加载节点失败
- `workflow.create5kNodes` - 创建5000节点
- `workflow.create10kNodes` - 创建10000节点
- `notFound.title` - 页面未找到
- `notFound.description` - 请求的页面不存在。
- `notFound.backToHome` - 返回首页
---
## ✅ 中文缺失的翻译27个
### 1. Common 通用模块 (1个)
- `common.operateSuccess` - Operation successful
### 2. KnowledgeBase 知识库模块 (3个)
- `knowledgeBase.models` - Model
- `knowledgeBase.owner` - Owner
- `knowledgeBase.operation` - Operation
### 3. Application 应用模块 (15个)
- `application.multi_agent` - Cluster
- `application.multi_agentDesc` - Create an Agent Cluster
- `application.current` - Current
- `application.versionName` - Version Name
- `application.versionNameTip` - Version number format: v[major version number].[next version number].[revision number] (e.g. v1.3.0)
- `application.agentName` - Agent Name
- `application.roleType` - Role Type
- `application.coordinator` - Coordinator
- `application.analyzer` - Analyzer
- `application.executor` - Executor
- `application.reviewer` - Reviewer
- `application.updateSubAgent` - Update Sub Agent
- `application.subAgentMaxLength` - Sub Agent maximum {{maxLength}}
- `application.capabilities` - Capabilities
### 4. Space 空间模块 (5个)
- `space.storageType` - Storage Type
- `space.rag` - RAG storage
- `space.ragDesc` - Based on vector retrieval, suitable for document Q&A and semantic search
- `space.neo4j` - Graph storage
- `space.neo4jDesc` - Based on knowledge graph, suitable for relational reasoning and path query
### 5. MemoryExtractionEngine 记忆提取引擎模块 (4个)
- `memoryExtractionEngine.coreEntitiesAfterDedup` - Core entities after deduplication
- `memoryExtractionEngine.extractRelationalTriples` - Extracted relational triples (partial)
- `memoryExtractionEngine.extractRelationalTriplesDesc` - There are a total of {{count}} segments with clear semantic boundaries
- `memoryExtractionEngine.theEffectOfEntityDisambiguationLLMDriven` - The effect of entity disambiguation (LLM driven)
---
## 🎯 建议
### 优先级 1 - 核心功能模块(需要立即补充)
1. **Role 角色管理** - 完整模块缺失15个键
2. **Tenant 租户管理** - 完整模块缺失20个键
3. **Product 产品管理** - 完整模块缺失13个键
4. **User 用户管理扩展** - 批量导入功能缺失13个键
### 优先级 2 - 功能增强(建议补充)
1. **Application 应用模块** - 多代理相关功能15个键
2. **Space 空间模块** - 存储类型配置5个键
3. **MemoryExtractionEngine** - 实体去重相关4个键
### 优先级 3 - 演示/测试功能(可选)
1. **Home/Workflow/NotFound** - 演示页面30个键
2. **通用计数器功能** - 测试功能5个键
---
## 📝 下一步行动
1. **补充英文翻译**: 优先补充 Role、Tenant、Product、User 模块的英文翻译
2. **补充中文翻译**: 补充 Application、Space、MemoryExtractionEngine 模块的中文翻译
3. **清理无用翻译**: 如果 Home/Workflow 等演示功能不再使用,可以考虑从中文文件中移除
4. **建立翻译规范**: 建议建立翻译键的命名规范和审查流程,避免未来出现遗漏
**报告生成完成**

View File

@@ -27,12 +27,45 @@ const ChatContent: FC<ChatContentProps> = ({
}) => {
// Scroll container reference for controlling auto-scroll to bottom
const scrollContainerRef = useRef<(HTMLDivElement | null)>(null)
const prevDataLengthRef = useRef(data.length);
const isScrolledToBottomRef = useRef(true); // Track if user is scrolled to bottom
// Track scroll position to determine if user is at bottom
useEffect(() => {
const handleScroll = () => {
if (scrollContainerRef.current) {
const { scrollTop, scrollHeight, clientHeight } = scrollContainerRef.current;
// Consider user is at bottom if within 20px of the bottom
isScrolledToBottomRef.current = scrollHeight - scrollTop - clientHeight < 20;
}
};
const container = scrollContainerRef.current;
if (container) {
container.addEventListener('scroll', handleScroll);
// Initial check
handleScroll();
}
return () => {
if (container) {
container.removeEventListener('scroll', handleScroll);
}
};
}, []);
// Auto-scroll to bottom when data changes to show latest messages
// When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom
// When data array length changes, auto-scroll to bottom
// If already scrolled to bottom, will auto-scroll to bottom
useEffect(() => {
setTimeout(() => {
if (scrollContainerRef.current) {
scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight;
// Auto-scroll if data length changed OR user is currently at bottom
if (data.length !== prevDataLengthRef.current || isScrolledToBottomRef.current) {
scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight;
}
prevDataLengthRef.current = data.length;
}
}, 0);
}, [data])

View File

@@ -2,10 +2,10 @@
* @Author: ZhaoYing
* @Date: 2026-02-02 15:12:42
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 14:06:28
* @Last Modified time: 2026-02-28 17:28:41
*/
/**
* BasicLayout Component
* BasicAuthLayout Component
*
* A minimal layout wrapper that provides:
* - User information initialization
@@ -26,12 +26,12 @@ import { useUser } from '@/store/user';
* Basic layout component for pages without navigation UI.
* Fetches user info and storage type on mount, then renders child routes.
*/
const BasicLayout: FC = () => {
const BasicAuthLayout: FC = () => {
const { getUserInfo } = useUser();
// Fetch user information and storage type on component mount
useEffect(() => {
getUserInfo();
getUserInfo(undefined, true); // Pass true to skip navigation jump
}, [getUserInfo]);
return (
@@ -42,4 +42,4 @@ const BasicLayout: FC = () => {
)
};
export default BasicLayout;
export default BasicAuthLayout;

View File

@@ -20,6 +20,7 @@
import { type FC, type Key, type ReactNode, useEffect } from 'react';
import { type RadioGroupProps } from 'antd';
import clsx from 'clsx'
import { useTranslation } from 'react-i18next';
/** Radio card option interface */
interface RadioCardOption {
@@ -33,6 +34,8 @@ interface RadioCardOption {
icon?: string;
/** Whether the option is disabled */
disabled?: boolean;
/** Whether the option is recommended */
recommend?: boolean;
/** Additional properties */
[key: string]: string | number | boolean | undefined | null | Key;
}
@@ -63,6 +66,7 @@ const RadioGroupCard: FC<RadioCardProps> = ({
allowClear = true,
block = false,
}) => {
const { t } = useTranslation();
/** Listen to value changes and trigger side effects via onValueChange callback */
useEffect(() => {
if (onValueChange) {
@@ -91,12 +95,13 @@ const RadioGroupCard: FC<RadioCardProps> = ({
})}>
{/* Render each option as a selectable card */}
{options.map(option => (
<div key={String(option.value)} className={clsx("rb:border rb:rounded-lg rb:w-full rb:p-[20px_12px] rb:text-center rb:cursor-pointer", {
<div key={String(option.value)} className={clsx("rb:relative rb:border rb:rounded-lg rb:w-full rb:p-[20px_12px] rb:text-center rb:cursor-pointer", {
'rb:bg-[rgba(21,94,239,0.06)] rb:border-[#155EEF]': option.value === value,
'rb:border-[#EBEBEB] rb:bg-[#ffffff]': option.value !== value,
'rb:opacity-[0.75]': option.disabled,
'rb:flex rb:items-center rb:text-left rb:gap-4': block,
})} onClick={() => handleChange(option)}>
{option.recommend && <div className="rb:absolute rb:right-0 rb:top-0 rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:p-[4px_8px]">{t('common.recommend')}</div>}
{/* Use custom render or default card layout */}
{itemRender ? itemRender(option) : (
<>

View File

@@ -41,6 +41,8 @@ interface SearchInputProps {
className?: string;
/** Input size */
size?: InputProps['size']
/** Maximum length of the input value */
maxLength?: number;
}
/** Search input component with debounce and throttle support */

View File

@@ -452,6 +452,9 @@ export const en = {
nextStep: 'Next Step',
prevStep: 'Previous Step',
exportSuccess: 'Export successful',
recommend: 'Recommend',
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
imageSquareRequired: 'Please upload a square image',
},
model: {
searchPlaceholder: 'search model…',
@@ -541,7 +544,8 @@ export const en = {
ollama: "Ollama",
xinference: "Xinference",
gpustack: "Gpustack",
bedrock: "Bedrock"
bedrock: "Bedrock",
nameInvalid: 'Model name can only contain letters, numbers, underscores and spaces, cannot be empty or pure whitespace',
},
modelNew: {
group: 'Model Group',
@@ -1642,6 +1646,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
scene_type_distribution: 'Scene Type Distribution',
general_type_distribution: 'General Type Distribution',
unmatched: 'Unmatched',
disagreementCase: 'Disagreement Case',
Pruned: 'Pruned',
pruning: 'Pruning',
pruning_desc: 'Text pruning {{count}} fragments'
},
memoryConversation: {
searchPlaceholder: 'Enter user ID...',

View File

@@ -1031,6 +1031,9 @@ export const zh = {
nextStep: '下一步',
prevStep: '上一步',
exportSuccess: '导出成功',
recommend: '推荐',
logoTip: `支持图片格式JPG、PNG\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
imageSquareRequired: '请上传正方形比例图片',
},
model: {
searchPlaceholder: '搜索模型…',
@@ -1178,7 +1181,8 @@ export const zh = {
ollama: "Ollama",
xinference: "Xinference",
gpustack: "Gpustack",
bedrock: "Bedrock"
bedrock: "Bedrock",
nameInvalid: '模型名称只能包含字母、数字、下划线和空格, 不能为空或纯空格',
},
timezones: {
'Asia/Shanghai': '中国标准时间 (UTC+8)',
@@ -1639,6 +1643,10 @@ export const zh = {
scene_type_distribution: '场景类型',
general_type_distribution: '通用类型',
unmatched: '未匹配',
disagreementCase: '不一致案例',
Pruned: '已剪枝',
pruning: '剪枝',
pruning_desc: '文本剪枝{{count}}个片段'
},
memoryConversation: {
chatEmpty:'有什么我可以帮您的吗?',

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-02 16:33:54
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-04 18:30:10
* @Last Modified time: 2026-02-28 17:21:20
*/
/**
* User Store
@@ -44,7 +44,7 @@ export interface UserState {
/** Update login information */
updateLoginInfo: (values: LoginInfo) => void;
/** Get user information */
getUserInfo: (flag?: boolean) => void;
getUserInfo: (flag?: boolean, notNeedJump?: boolean) => void;
/** Clear user information */
clearUserInfo: () => void;
/** Logout user */
@@ -73,13 +73,13 @@ export const useUser = create<UserState>((set, get) => ({
cookieUtils.set('refreshToken', values.refresh_token);
set({ loginInfo: values });
},
getUserInfo: async (flag?: boolean) => {
getUserInfo: async (flag?: boolean, notNeedJump?: boolean) => {
if (!cookieUtils.get('authToken')) {
return
}
const { checkJump } = get()
const localUser = JSON.parse(localStorage.getItem('user') || '{}') as User;
if (localUser.id) {
if (localUser.id && !notNeedJump) {
checkJump()
return
}

View File

@@ -0,0 +1,50 @@
/*
* @Author: ZhaoYing
* @Date: 2026-03-02 13:46:53
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-02 14:38:33
*/
/**
* Form validation utilities
*/
/** Minimal shape of an upload entry consumed by the validators in this file. */
interface UploadFile {
  /** The raw browser Blob/File backing the uploaded entry. */
  originFileObj: Blob;
  /** Upload entries carry extra bookkeeping fields; allow them through untyped. */
  [key: string]: unknown;
}
/**
* Validate if uploaded image is square (width === height)
* @param errorMessage - Error message to display when validation fails
* @returns Ant Design form validator
*/
/**
 * Validate that an uploaded image is square (width === height).
 *
 * Produces an async form validator (Ant Design style): it resolves when no
 * file is present or the image is square, and rejects with `errorMessage`
 * otherwise. Fix: the object URL created for dimension probing is now
 * revoked in both the load and error paths — previously it was never
 * released, leaking a blob URL on every validation.
 *
 * @param errorMessage - Error message to display when validation fails
 * @returns Ant Design form validator function
 */
export const validateSquareImage = (errorMessage: string = 'Image must be square') => {
  return (_: unknown, value: UploadFile | UploadFile[] | undefined) => {
    // No file selected — valid here; required-ness is a separate rule.
    if (!value || (Array.isArray(value) && value.length === 0)) {
      return Promise.resolve();
    }
    const file = Array.isArray(value) ? value[0] : value;
    if (file?.originFileObj) {
      return new Promise<void>((resolve, reject) => {
        const img = new Image();
        // Object URLs keep the blob alive until explicitly revoked, so
        // release it as soon as the dimensions are known (or loading fails).
        const objectUrl = URL.createObjectURL(file.originFileObj);
        img.onload = () => {
          URL.revokeObjectURL(objectUrl);
          if (img.width === img.height) {
            resolve();
          } else {
            reject(new Error(errorMessage));
          }
        };
        img.onerror = () => {
          URL.revokeObjectURL(objectUrl);
          reject(new Error('Failed to load image'));
        };
        img.src = objectUrl;
      });
    }
    // Entry without a raw blob (e.g. an already-uploaded file) — nothing to check.
    return Promise.resolve();
  };
};
/**
 * Name pattern: the first character must be a letter, digit, or CJK
 * character; subsequent characters may additionally include whitespace.
 * - Cannot be empty or pure whitespace
 * - Cannot start with a space
 *
 * NOTE(review): the `nameInvalid` i18n messages mention underscores as
 * allowed, but this pattern does not permit them — confirm the intended
 * character set.
 */
export const stringRegExp = /^[a-zA-Z0-9\u4e00-\u9fa5][a-zA-Z0-9\u4e00-\u9fa5\s]*$/

View File

@@ -12,6 +12,7 @@ import dayjs from 'dayjs'
import type { ApiKey, ApiKeyModalRef } from '../types';
import RbModal from '@/components/RbModal'
import { createApiKey, updateApiKey } from '@/api/apiKey';
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -78,7 +79,7 @@ const ApiKeyModal = forwardRef<ApiKeyModalRef, CreateModalProps>(({
form.validateFields()
.then((values) => {
const { memory, rag, expires_at, ...rest } = values
let scopes = []
const scopes = []
if (memory) {
scopes.push('memory')
@@ -130,7 +131,11 @@ const ApiKeyModal = forwardRef<ApiKeyModalRef, CreateModalProps>(({
<FormItem
name="name"
label={t('apiKey.name')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.enter')} />
</FormItem>
@@ -138,6 +143,7 @@ const ApiKeyModal = forwardRef<ApiKeyModalRef, CreateModalProps>(({
<FormItem
name="description"
label={t('apiKey.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} rows={3} />
</FormItem>

View File

@@ -169,8 +169,8 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
getApplicationConfig(id as string).then(res => {
const response = res as Config
const { skills, variables } = response
let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
let allTools = Array.isArray(response.tools) ? response.tools : []
const allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
const allTools = Array.isArray(response.tools) ? response.tools : []
const memoryContent = response.memory?.memory_config_id
const parsedMemoryContent = memoryContent === null || memoryContent === ''
? undefined
@@ -431,7 +431,11 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
</Button>
</div>
<Form.Item name="system_prompt" className="rb:mb-0!">
<Form.Item
name="system_prompt"
className="rb:mb-0!"
rules={[{ max: 10000 }]}
>
<Input.TextArea
placeholder={t('application.promptPlaceholder')}
styles={{

View File

@@ -29,7 +29,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => {
const { t } = useTranslation();
const activeMethods = ['POST'];
const { message, modal } = App.useApp()
const copyContent = window.location.origin + '/v1/chat'
const copyContent = window.location.origin + '/v1/app/chat'
const apiKeyModalRef = useRef<ApiKeyModalRef>(null);
const apiKeyConfigModalRef = useRef<ApiKeyConfigModalRef>(null);
const [apiKeyList, setApiKeyList] = useState<ApiKey[]>([])

View File

@@ -21,6 +21,7 @@ import WorkflowIcon from '@/assets/images/application/workflow.svg'
import type { ApplicationModalData, ApplicationModalRef, Application } from '../types'
import RbModal from '@/components/RbModal'
import { addApplication, updateApplication } from '@/api/application'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -131,13 +132,18 @@ const ApplicationModal = forwardRef<ApplicationModalRef, ApplicationModalProps>(
<FormItem
name="name"
label={t('application.applicationName')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.enter')} />
</FormItem>
<FormItem
name="description"
label={t('application.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('common.enter')} />
</FormItem>

View File

@@ -152,7 +152,11 @@ const MemberModal = forwardRef<MemberModalRef, MemberModalProps>(({
<FormItem
name="email"
label={t('member.email')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ type: 'email' },
{ max: 100 },
]}
>
<Input placeholder={t('common.enterPlaceholder', { title: t('member.email') })} disabled={!!editingUser} />
</FormItem>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 17:30:11
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 21:04:14
* @Last Modified time: 2026-03-02 11:41:12
*/
/**
* Result Component
@@ -91,7 +91,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
setDeduplication({...initObj} as ModuleItem)
setTestResult({} as TestResult)
const handleStreamMessage = (list: SSEMessage[]) => {
list.forEach((data: AnyObject) => {
switch(data.event) {
case 'text_preprocessing': // Start text preprocessing
@@ -104,7 +104,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
case 'text_preprocessing_result': // Text preprocessing in progress
setTextPreprocessing(prev => ({
...prev,
data: [...prev.data, data.data?.data]
data: [...prev.data, data.data?.deleted_messages ? { deleted_messages: data.data?.deleted_messages } : data.data?.data],
}))
break
case 'text_preprocessing_complete': // Text preprocessing complete
@@ -193,9 +193,9 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
dialogue_text: t('memoryExtractionEngine.exampleText'),
custom_text: runForm.getFieldValue('custom_text')
}, handleStreamMessage)
.finally(() => {
setRunLoading(false)
})
.finally(() => {
setRunLoading(false)
})
}
const completedNum = [textPreprocessing, knowledgeExtraction, creatingNodesEdges, deduplication].filter(item => item.status === 'completed').length
const deduplicationData = groupDataByType(deduplication.data, 'result_type')
@@ -251,10 +251,10 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
</div>
</>
: !testResult || Object.keys(testResult).length === 0
? <RbAlert color="orange" icon={<ExclamationCircleFilled />} className="rb:mb-3.5">
{t('memoryExtractionEngine.warning')}
</RbAlert>
: <RbAlert color="green" icon={<ExclamationCircleFilled />} className="rb:mb-3.5">
? <RbAlert color="orange" icon={<ExclamationCircleFilled />} className="rb:mb-3.5">
{t('memoryExtractionEngine.warning')}
</RbAlert>
: <RbAlert color="green" icon={<ExclamationCircleFilled />} className="rb:mb-3.5">
{t('memoryExtractionEngine.success')}
</RbAlert>
}
@@ -266,15 +266,28 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
headerType="borderL"
headerClassName="rb:before:bg-[#155EEF]!"
>
{textPreprocessing.data.map((vo, index) => (
<div key={index} className="rb:mb-3 rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">
<Markdown content={'-' + t('memoryExtractionEngine.fragment') + vo.chunk_index + ': ' + (vo.content.startsWith('\n') ? vo.content : '\n' + vo.content)} />
</div>
))}
{textPreprocessing.data.map((vo, index) => {
if (vo.deleted_messages) {
return <div key={index} className="rb:mb-3 rb:pb-1 rb:border-b rb:border-b-[#EBEBEB]">
<div className="rb:font-medium rb:text-[12px] rb:mb-2">{t('memoryExtractionEngine.Pruned')}</div>
{vo.deleted_messages.map((msg: any, idx: number) => (
<div key={idx} className="rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">
<Markdown content={'-' + t('memoryExtractionEngine.pruning') + (idx + 1) + ': ' + msg.content} />
</div>
))}
</div>
}
return (
<div key={index} className="rb:mb-3 rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">
<Markdown content={'-' + t('memoryExtractionEngine.fragment') + vo.chunk_index + ': ' + (vo.content.startsWith('\n') ? vo.content : '\n' + vo.content)} />
</div>
)
})}
{formatTime(textPreprocessing)}
{textPreprocessing.result &&
<RbAlert color="blue" icon={<CheckCircleFilled />} className="rb:mt-3">
{t('memoryExtractionEngine.text_preprocessing_desc', { count: textPreprocessing.result.total_chunks })},
{t('memoryExtractionEngine.pruning_desc', { count: textPreprocessing.result.pruning.deleted_count || 0 })},
{t('memoryExtractionEngine.text_preprocessing_desc', { count: textPreprocessing.result.total_chunks })},
{t('memoryExtractionEngine.chunkerStrategy')}: {t(`memoryExtractionEngine.${lowercaseFirst(textPreprocessing.result.chunker_strategy)}`)}
</RbAlert>
}
@@ -286,7 +299,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
headerType="borderL"
headerClassName="rb:before:bg-[#155EEF]!"
>
{knowledgeExtraction.data.map((vo, index) =>
{knowledgeExtraction.data.map((vo, index) =>
<div key={index} className="rb:mb-3 rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">{vo.statement}</div>
)}
{formatTime(knowledgeExtraction)}
@@ -345,31 +358,30 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
{Object.keys(resultObj).map((key, index) => {
const keys = (resultObj as Record<string, string>)[key].split('.')
return (
<div key={index}>
<div className="rb:text-[24px] rb:leading-7.5 rb:font-extrabold">{(testResult?.[keys[0] as keyof TestResult] as any)?.[keys[1]]}</div>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">{t(`memoryExtractionEngine.${key}`)}</div>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#369F21] rb:leading-3.5 rb:font-regular">
{}
{key === 'extractTheNumberOfEntities' && testResult.dedup
? t(`memoryExtractionEngine.${key}Desc`, {
num: testResult.dedup.total_merged_count,
exact: testResult.dedup.breakdown.exact,
fuzzy: testResult.dedup.breakdown.fuzzy,
llm: testResult.dedup.breakdown.llm,
})
: key === 'numberOfEntityDisambiguation' && testResult.disambiguation
? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.disambiguation.effects?.length, block_count: testResult.disambiguation.block_count })
: key === 'numberOfRelationalTriples' && testResult.triplets
? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.triplets.count })
:t(`memoryExtractionEngine.${key}Desc`)
}
<div key={index}>
<div className="rb:text-[24px] rb:leading-7.5 rb:font-extrabold">{(testResult?.[keys[0] as keyof TestResult] as any)?.[keys[1]]}</div>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:leading-4 rb:font-regular">{t(`memoryExtractionEngine.${key}`)}</div>
<div className="rb:mt-1 rb:text-[12px] rb:text-[#369F21] rb:leading-3.5 rb:font-regular">
{key === 'extractTheNumberOfEntities' && testResult.dedup
? t(`memoryExtractionEngine.${key}Desc`, {
num: testResult.dedup.total_merged_count,
exact: testResult.dedup.breakdown.exact,
fuzzy: testResult.dedup.breakdown.fuzzy,
llm: testResult.dedup.breakdown.llm,
})
: key === 'numberOfEntityDisambiguation' && testResult.disambiguation
? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.disambiguation.effects?.length, block_count: testResult.disambiguation.block_count })
: key === 'numberOfRelationalTriples' && testResult.triplets
? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.triplets.count })
:t(`memoryExtractionEngine.${key}Desc`)
}
</div>
</div>
</div>
)})}
)})}
</div>
</RbCard>
}
{testResult?.dedup?.impact && testResult.dedup.impact?.length > 0 &&
<RbCard
title={t('memoryExtractionEngine.entityDeduplicationImpact')}
@@ -388,7 +400,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
</RbAlert>
</RbCard>
}
{testResult?.disambiguation && testResult.disambiguation?.effects?.length > 0 &&
<RbCard
title={t('memoryExtractionEngine.theEffectOfEntityDisambiguationLLMDriven')}
@@ -399,7 +411,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
<div key={index} className={clsx("rb:text-[12px] rb:text-[#5B6167] rb:leading-4", {
'rb:mt-4': index > 0,
})}>
<div className="rb:font-medium rb:mb-2">Disagreement Case {index +1}:</div>
<div className="rb:font-medium rb:mb-2">{t('memoryExtractionEngine.disagreementCase')} {index +1}:</div>
-{item.left.name}({item.left.type}) vs {item.right.name}({item.right.type}) <span className="rb:text-[#369F21]">{item.result}</span>
</div>
))}
@@ -409,7 +421,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
</RbAlert>
</RbCard>
}
{testResult?.core_entities && testResult?.core_entities.length > 0 &&
<RbCard
title={t('memoryExtractionEngine.coreEntitiesAfterDedup')}
@@ -433,7 +445,7 @@ const Result: FC<ResultProps> = ({ loading, handleSave }) => {
</div>
</RbCard>
}
{testResult?.triplet_samples && testResult?.triplet_samples.length > 0 &&
<RbCard
title={t('memoryExtractionEngine.extractRelationalTriples')}

View File

@@ -18,6 +18,7 @@ import RbModal from '@/components/RbModal'
import { createMemoryConfig, updateMemoryConfig } from '@/api/memory'
import { getOntologyScenesSimpleUrl } from '@/api/ontology'
import CustomSelect from '@/components/CustomSelect';
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -110,7 +111,11 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
<FormItem
name="config_name"
label={t('memory.configurationName')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.pleaseEnter')} />
</FormItem>
@@ -118,6 +123,7 @@ const MemoryForm = forwardRef<MemoryFormRef, MemoryFormProps>(({
<FormItem
name="config_desc"
label={t('memory.desc')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</FormItem>

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:50:10
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:50:10
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-27 10:20:51
*/
/**
* Model List View
@@ -21,7 +21,7 @@ import PageEmpty from '@/components/Empty/PageEmpty';
import Tag from '@/components/Tag';
import KeyConfigModal from './components/KeyConfigModal'
import ModelListDetail from './components/ModelListDetail'
import { getLogoUrl } from './utils'
import { getListLogoUrl } from './utils'
/**
* Model list component
@@ -70,7 +70,7 @@ const ModelList = forwardRef<BaseRef, { query: any; handleEdit: (vo?: ModelListI
<RbCard
key={item.provider}
title={t(`modelNew.${item.provider}`)}
avatarUrl={getLogoUrl(item.logo)}
avatarUrl={getListLogoUrl(item.provider, item.logo)}
avatar={
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
{item.provider[0].toUpperCase()}

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:28
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:49:28
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-28 17:24:05
*/
/**
* Custom Model Modal
@@ -20,6 +20,7 @@ import CustomSelect from '@/components/CustomSelect'
import UploadImages from '@/components/Upload/UploadImages'
import { updateCustomModel, addCustomModel, modelTypeUrl, modelProviderUrl } from '@/api/models'
import { getFileLink } from '@/api/fileStorage'
import { validateSquareImage, stringRegExp } from '@/utils/validator'
/**
* Custom model modal component
@@ -50,7 +51,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
setModel(model);
form.setFieldsValue({
...model,
logo: model.logo ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
logo: model.logo && model.logo.startsWith('http') ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
});
} else {
setIsEdit(false);
@@ -65,7 +66,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data)
res.then(() => {
refresh && refresh(isEdit)
refresh?.(isEdit)
handleClose()
message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess'))
})
@@ -79,7 +80,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
.validateFields()
.then((values) => {
const { logo, ...rest } = values;
let formData: CustomModelForm = {
const formData: CustomModelForm = {
...rest
}
@@ -125,14 +126,22 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
name="logo"
label={t('modelNew.logo')}
valuePropName="fileList"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
rules={[
{ required: true, message: t('common.pleaseSelect') },
{ validator: validateSquareImage(t('common.imageSquareRequired')) }
]}
extra={t('common.logoTip')?.split('\n').map((vo, index) => <div key={index}>{vo}</div>)}
>
<UploadImages />
<UploadImages fileSize={2} />
</Form.Item>
<Form.Item
name="name"
label={t('modelNew.name')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.name') }) }]}
rules={[
{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.name') }) },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
@@ -166,6 +175,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
<Form.Item
name="description"
label={t('modelNew.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</Form.Item>

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:33
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:49:33
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-02 12:23:13
*/
/**
* Group Model Modal
@@ -21,6 +21,7 @@ import { updateCompositeModel, modelTypeUrl, addCompositeModel } from '@/api/mod
import UploadImages from '@/components/Upload/UploadImages'
import ModelImplement from './ModelImplement'
import { getFileLink } from '@/api/fileStorage'
import { validateSquareImage, stringRegExp } from '@/utils/validator'
/**
* Group model modal component
@@ -133,15 +134,26 @@ const GroupModelModal = forwardRef<GroupModelModalRef, GroupModelModalProps>(({
name="logo"
label={t('modelNew.logo')}
valuePropName="fileList"
rules={[{ required: true, message: t('common.pleaseSelect') }]}
rules={[
{ required: true, message: t('common.pleaseSelect') },
{ validator: validateSquareImage(t('common.imageSquareRequired')) }
]}
extra={t('common.logoTip')?.split('\n').map((vo, index) => <div key={index}>{vo}</div>)}
>
<UploadImages />
<UploadImages
fileSize={2}
fileType={['png', 'jpg']}
/>
</Form.Item>
<Form.Item
name="name"
label={t('modelNew.name')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
@@ -165,6 +177,7 @@ const GroupModelModal = forwardRef<GroupModelModalRef, GroupModelModalProps>(({
<Form.Item
name="description"
label={t('modelNew.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</Form.Item>

View File

@@ -121,6 +121,7 @@ const tabKeys = ['group', 'list', 'square']
{activeTab !== 'list' &&
<Form.Item name="search" noStyle>
<SearchInput
maxLength={50}
placeholder={t(`modelNew.${activeTab}SearchPlaceholder`)}
className="rb:w-70!"
/>

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 16:50:22
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 16:50:22
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-27 10:22:46
*/
/**
* Utility functions for Model Management
@@ -40,5 +40,26 @@ export const getLogoUrl = (logo?: string) => {
return logo
}
return ICONS[logo as keyof typeof ICONS] || undefined
}
/**
 * Get logo URL from provider name or URL.
 *
 * Resolution order: built-in icon keyed by provider, then the raw logo
 * value when it is already an absolute URL, then a built-in icon keyed
 * by the logo value itself.
 *
 * @param provider - Provider name
 * @param logo - Provider name or logo URL
 * @returns Logo URL or undefined when nothing matches
 */
export const getListLogoUrl = (provider?: string, logo?: string) => {
  // Prefer the built-in icon registered for the provider itself.
  const providerIcon = ICONS[provider as keyof typeof ICONS]
  if (providerIcon) return providerIcon
  if (!logo) {
    return undefined
  }
  // Absolute URLs (e.g. user-uploaded logos) are used as-is.
  if (logo.startsWith('http')) {
    return logo
  }
  // Otherwise treat the logo value as a built-in icon key.
  return ICONS[logo as keyof typeof ICONS] || undefined
}

View File

@@ -182,7 +182,10 @@ const OntologyClassExtractModal = forwardRef<OntologyClassExtractModalRef, Ontol
<FormItem
name="scenario"
label={t('ontology.scenario')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 2000 },
]}
>
<Input.TextArea placeholder={t('ontology.scenarioPlaceholder')} />
</FormItem>

View File

@@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next';
import type { AddClassItem, OntologyClassModalRef } from '../types'
import RbModal from '@/components/RbModal'
import { createOntologyClass } from '@/api/ontology'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -105,7 +106,11 @@ const OntologyClassModal = forwardRef<OntologyClassModalRef, OntologyClassModalP
<FormItem
name="class_name"
label={t('ontology.class_name')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.enter')} />
</FormItem>
@@ -113,6 +118,7 @@ const OntologyClassModal = forwardRef<OntologyClassModalRef, OntologyClassModalP
<FormItem
name="class_description"
label={t('ontology.class_description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('ontology.classDescriptionPlaceholder')} />
</FormItem>

View File

@@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next';
import type { OntologyItem, OntologyModalData, OntologyModalRef } from '../types'
import RbModal from '@/components/RbModal'
import { createOntologyScene, updateOntologyScene } from '@/api/ontology'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -109,7 +110,11 @@ const OntologyModal = forwardRef<OntologyModalRef, OntologyModalProps>(({
<FormItem
name="scene_name"
label={t('ontology.scene_name')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.enter')} />
</FormItem>
@@ -117,6 +122,7 @@ const OntologyModal = forwardRef<OntologyModalRef, OntologyModalProps>(({
<FormItem
name="scene_description"
label={t('ontology.scene_description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('ontology.descriptionPlaceholder')} />
</FormItem>

View File

@@ -17,6 +17,7 @@ import type { AiPromptModalRef } from '@/views/ApplicationConfig/types'
import exitIcon from '@/assets/images/knowledgeBase/exit.png';
import type { SkillFormData } from '../types'
import { getSkillDetail, createSkill, updateSkill } from '@/api/skill'
import { stringRegExp } from '@/utils/validator';
/**
* Skill Configuration Page Component
@@ -110,7 +111,7 @@ const SkillConfig: FC = () => {
// Format tools data for API
const formData = {
...rest,
tools: tools?.map((item: any) => ({
tools: tools?.map((item) => ({
tool_id: item.tool_id,
operation: item.operation
}))
@@ -144,13 +145,18 @@ const SkillConfig: FC = () => {
<Form.Item
name="name"
label={t('skills.name')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('skills.name') }) }]}
rules={[
{ required: true, message: t('common.inputPlaceholder', { title: t('skills.name') }) },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
<Form.Item
name="description"
label={t('skills.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea placeholder={t('skills.descriptionPlaceholder')} />
</Form.Item>

View File

@@ -17,6 +17,8 @@ export interface SkillFormData {
tools: Array<{
/** Tool identifier */
tool_id: string;
/** Tool operation/action */
operation?: string;
}>;
/** Skill configuration settings */
config: {

View File

@@ -23,6 +23,7 @@ import UploadImages from '@/components/Upload/UploadImages'
import { getFileLink } from '@/api/fileStorage'
import ragIcon from '@/assets/images/space/rag.png'
import neo4jIcon from '@/assets/images/space/neo4j.png'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -34,8 +35,8 @@ interface SpaceModalProps {
}
/** Storage types */
const types: StorageType[] = [
'rag',
'neo4j',
'rag',
]
/** Type icons mapping */
const typeIcons: Record<StorageType, string> = {
@@ -91,7 +92,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
setCurrentStep(1)
} else {
const { icon, ...rest } = values
let formData: SpaceModalData = {
const formData: SpaceModalData = {
...rest
}
if (icon?.response?.data.file_id) {
@@ -154,6 +155,9 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
<Form
form={form}
layout="vertical"
initialValues={{
storage_type: types[0],
}}
>
<Form.Item
name="icon"
@@ -161,14 +165,19 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
valuePropName="fileList"
hidden={currentStep === 1}
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]}
extra={t('common.logoTip')?.split('\n').map((vo, index) => <div key={index}>{vo}</div>)}
>
<UploadImages />
<UploadImages fileSize={2} />
</Form.Item>
<FormItem
name="name"
label={t('space.spaceName')}
hidden={currentStep === 1}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('space.spaceName') }) }]}
rules={[
{ required: true, message: t('common.inputPlaceholder', { title: t('space.spaceName') }) },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('common.inputPlaceholder', { title: t('space.spaceName') })} />
</FormItem>
@@ -183,7 +192,8 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
value: type,
label: t(`space.${type}`),
labelDesc: t(`space.${type}Desc`),
icon: typeIcons[type]
icon: typeIcons[type],
recommend: type === 'neo4j',
}))}
block={true}
/>

View File

@@ -4,7 +4,7 @@
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-03 10:12:33
*/
import { useEffect, useState, type FC } from 'react';
import { useEffect, useState, useRef, type FC } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getSelection, $isRangeSelection, $isTextNode, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical';
@@ -30,6 +30,26 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
const [showSuggestions, setShowSuggestions] = useState(false);
const [selectedIndex, setSelectedIndex] = useState(0);
const [popupPosition, setPopupPosition] = useState({ top: 0, left: 0 });
const popupRef = useRef<HTMLDivElement>(null);
// Keep the currently highlighted suggestion visible inside the scrollable
// popup: nudge scrollTop by the minimal amount when the selected row sits
// above or below the container's visible box.
const scrollSelectedIntoView = () => {
  const container = popupRef.current;
  if (!container) return;
  const selected = container.querySelector('[data-selected="true"]') as HTMLElement | null;
  if (!selected) return;
  const containerBox = container.getBoundingClientRect();
  const selectedBox = selected.getBoundingClientRect();
  if (selectedBox.bottom > containerBox.bottom) {
    // Selected row extends past the bottom edge — scroll down just enough.
    container.scrollTop += selectedBox.bottom - containerBox.bottom;
  } else if (selectedBox.top < containerBox.top) {
    // Selected row sits above the top edge — scroll up just enough.
    container.scrollTop -= containerBox.top - selectedBox.top;
  }
};
// Listen to editor updates and show suggestions when '/' is typed
useEffect(() => {
@@ -140,7 +160,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
};
// Group suggestions by node ID
const groupedSuggestions = options.reduce((groups: Record<string, any[]>, suggestion) => {
const groupedSuggestions = options.reduce((groups: Record<string, Suggestion[]>, suggestion) => {
const { nodeData } = suggestion
const nodeId = nodeData.id as string;
if (!groups[nodeId]) {
@@ -190,7 +210,9 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
while (nextIndex < allOptions.length && allOptions[nextIndex].disabled) {
nextIndex++;
}
return nextIndex >= allOptions.length ? prev : nextIndex;
const newIndex = nextIndex >= allOptions.length ? prev : nextIndex;
setTimeout(() => scrollSelectedIntoView(), 0);
return newIndex;
});
return true;
}
@@ -210,7 +232,9 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
while (prevIndex >= 0 && allOptions[prevIndex].disabled) {
prevIndex--;
}
return prevIndex < 0 ? prev : prevIndex;
const newIndex = prevIndex < 0 ? prev : prevIndex;
setTimeout(() => scrollSelectedIntoView(), 0);
return newIndex;
});
return true;
}
@@ -247,6 +271,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
}
return (
<div
ref={popupRef}
data-autocomplete-popup="true"
onMouseDown={(e) => e.preventDefault()}
style={{
@@ -279,6 +304,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }>
return (
<div
key={option.key}
data-selected={selectedIndex === globalIndex}
style={{
padding: '8px 12px',
cursor: option.disabled ? 'not-allowed' : 'pointer',

View File

@@ -88,4 +88,4 @@ export default defineConfig({
},
},
},
})
})