From 33238d34c94d421050c5c0579c6bb1e005bd05a8 Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Thu, 26 Feb 2026 10:17:44 +0800 Subject: [PATCH 001/164] [fix]Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) --- api/app/tasks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index d60af6e5..499b93a5 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -255,7 +255,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" return result - try: + def sync_task(): trio.run( lambda: _run( row=task, @@ -270,6 +270,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): with_community=with_community, ) ) + try: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(sync_task) + future.result() # Blocks until the task completes except Exception as e: progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n" progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" From 67053ab8aeb23440934ef2581af3499ae56be95b Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 26 Feb 2026 13:35:07 +0800 Subject: [PATCH 002/164] fix(workspace member): After the space inviter is removed, it can still be invited again. --- api/app/repositories/workspace_repository.py | 1 + api/app/services/workspace_service.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/api/app/repositories/workspace_repository.py b/api/app/repositories/workspace_repository.py index 70ed7521..87b0e20f 100644 --- a/api/app/repositories/workspace_repository.py +++ b/api/app/repositories/workspace_repository.py @@ -115,6 +115,7 @@ class WorkspaceRepository: self.db.query(Workspace) .join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id) .filter(WorkspaceMember.user_id == user_id) + .filter(WorkspaceMember.is_active.is_(True)) .filter(Workspace.is_active.is_(True)) .order_by(Workspace.updated_at.desc()) .all() diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 9ee98fa0..6f102695 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -70,10 +70,10 @@ def delete_workspace_member( _check_workspace_admin_permission(db, workspace_id, user) workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id) if not workspace_member: - raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND) + raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND) if workspace_member.workspace_id != workspace_id: - raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) + raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND) try: workspace_member.is_active = False From 4f0b653a822119c0ccf2fbbc0b75262113743794 Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Thu, 26 Feb 2026 19:04:42 +0800 Subject: [PATCH 003/164] =?UTF-8?q?=E3=80=90fix]The=20complexity=20and=20v?= =?UTF-8?q?olume=20of=20the=20document=20content=20require=20an=20extended?= =?UTF-8?q?=20timeframe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/celery_app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index db78a368..265cd2ab 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -43,8 +43,8 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=1800, # 30分钟硬超时 - task_soft_time_limit=1500, # 25分钟软超时 + task_time_limit=3600, # 60分钟硬超时 + task_soft_time_limit=3000, # 50分钟软超时 # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution From 550bd4da231f67e78f713a2424574efb3a356596 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 14:47:23 +0800 Subject: [PATCH 004/164] [fix]Modify the person who generates the user summary --- .../core/memory/utils/prompt/prompt_utils.py | 14 +++++++--- .../utils/prompt/prompts/user_summary.jinja2 | 26 ++++++++++++++----- api/app/services/user_memory_service.py | 24 ++++++++++++++++- 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 50d31f2a..d88f50cf 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -400,7 +400,8 @@ async def render_user_summary_prompt( user_id: str, entities: str, statements: str, - language: str = "zh" + language: str = "zh", + user_display_name: str = None ) -> str: """ Renders the user summary prompt using the user_summary.jinja2 template. @@ -410,16 +411,22 @@ async def render_user_summary_prompt( entities: Core entities with frequency information statements: Representative statement samples language: The language to use for summary generation ("zh" for Chinese, "en" for English) + user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user") Returns: Rendered prompt content as string """ + # 如果没有提供 user_display_name,使用默认值 + if user_display_name is None: + user_display_name = "该用户" if language == "zh" else "the user" + template = prompt_env.get_template("user_summary.jinja2") rendered_prompt = template.render( user_id=user_id, entities=entities, statements=statements, - language=language + language=language, + user_display_name=user_display_name ) # 记录渲染结果到提示日志 @@ -429,7 +436,8 @@ async def render_user_summary_prompt( 'user_id': user_id, 'entities_len': len(entities), 'statements_len': len(statements), - 'language': language + 'language': language, + 'user_display_name': user_display_name }) return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 index 35619112..30b48719 100644 --- a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 @@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti {% endif %} ===Inputs=== -{% if user_id %} -- User ID: {{ user_id }} +{% if user_display_name %} +- User Display Name: {{ user_display_name }} {% endif %} {% if entities %} - Core Entities & Frequency: {{ entities }} @@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti 3. Avoid excessive adjectives and empty phrases 4. Strictly follow the output format specified below +{% if language == "zh" %} +**【严格人称规定】** +- 在描述用户时,必须使用"{{ user_display_name }}"作为人称 +- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户 +- 绝对禁止在摘要中出现任何形式的UUID或ID字符串 +- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA) +{% else %} +**【STRICT PRONOUN RULES】** +- When describing the user, you MUST use "{{ user_display_name }}" as the reference +- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user +- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary +- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they) +{% endif %} + **Section-Specific Requirements:** {% if language == "zh" %} @@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti {% if language == "zh" %} Example Input: -- User ID: user_12345 +- User Display Name: 张三 - Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7) - Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则 Example Output: 【基本介绍】 -我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 +张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 【性格特点】 性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。 @@ -121,13 +135,13 @@ Example Output: "让每一个产品决策都充满温度。" {% else %} Example Input: -- User ID: user_12345 +- User Display Name: John - Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7) - Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle Example Output: 【Basic Introduction】 -This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities. +John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities. 【Personality Traits】 Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution. diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 80413c12..e34756b9 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1163,11 +1163,32 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st """ from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt from app.core.language_utils import validate_language + from app.repositories.end_user_repository import EndUserRepository + from app.db import get_db import re # 验证语言参数 language = validate_language(language) + # 获取用户的 other_name 字段 + user_display_name = "该用户" if language == "zh" else "the user" + if end_user_id: + try: + # 获取数据库会话并查询用户信息 + db = next(get_db()) + try: + repo = EndUserRepository(db) + end_user = repo.get_by_id(uuid.UUID(end_user_id)) + if end_user and end_user.other_name: + user_display_name = end_user.other_name + logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") + else: + logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") + finally: + db.close() + except Exception as e: + logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") + # 创建 UserSummaryHelper 实例 user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123")) @@ -1184,7 +1205,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st user_id=user_summary_tool.user_id, entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)", statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)", - language=language + language=language, + user_display_name=user_display_name ) messages = [ From 97d8168824a2d9afb20b44a26fe99ca8cf1230ba Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 14:59:28 +0800 Subject: [PATCH 005/164] [fix]Reduce the default number of items returned for popular tags --- api/app/core/memory/analytics/hot_memory_tags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index f99b811e..5ffc6fed 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -139,10 +139,10 @@ async def get_raw_tags_from_db( return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 - 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 + 查询更多的标签(limit=10)给LLM提供更丰富的上下文进行筛选。 Args: end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id From f7d92be5ea452f38142a47f2272ec2ce69d397d3 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 15:08:06 +0800 Subject: [PATCH 006/164] [changes]Ensure that there are sufficient labels for LLM to process, and control the number of label returns. --- api/app/core/memory/analytics/hot_memory_tags.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 5ffc6fed..abb0f138 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -142,11 +142,11 @@ async def get_raw_tags_from_db( async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 - 查询更多的标签(limit=10)给LLM提供更丰富的上下文进行筛选。 + 查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。 Args: end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id - limit: 返回的标签数量限制 + limit: 最终返回的标签数量限制(默认10) by_user: 是否按user_id查询(默认False,按end_user_id查询) Raises: @@ -161,8 +161,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = # 使用项目的Neo4jConnector connector = Neo4jConnector() try: - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user) + # 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文) + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) if not raw_tags_with_freq: return [] @@ -177,7 +178,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = if tag in meaningful_tag_names: final_tags.append((tag, freq)) - return final_tags + # 4. 限制返回的标签数量 + return final_tags[:limit] finally: # 确保关闭连接 await connector.close() From 5253cf3899e265208713d95b3281457b05298b2a Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 16:09:22 +0800 Subject: [PATCH 007/164] [fix]Address the shortcomings of intelligent pruning --- .../data_preprocessing/data_pruning.py | 518 ++++++++++++++---- 1 file changed, 423 insertions(+), 95 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index d19e511b..2d0142c6 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -5,14 +5,17 @@ - 对话级一次性抽取判定相关性 - 仅对"不相关对话"的消息按比例删除 - 重要信息(时间、编号、金额、联系方式、地址等)优先保留 +- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化 """ +import asyncio import os import hashlib import json import re +from collections import OrderedDict from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict, Tuple, Set from pydantic import BaseModel, Field from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext @@ -36,6 +39,23 @@ class DialogExtractionResponse(BaseModel): keywords: List[str] = Field(default_factory=list) +class MessageImportanceResponse(BaseModel): + """消息重要性批量判断的结构化返回(用于LLM语义判断)。 + + - importance_scores: 消息索引到重要性分数的映射 (0-10分) + - reasons: 可选的判断理由 + """ + importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射") + reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由") + + +class QAPair(BaseModel): + """问答对模型,用于识别和保护对话中的问答结构。""" + question_idx: int = Field(..., description="问题消息的索引") + answer_idx: int = Field(..., description="答案消息的索引") + confidence: float = Field(default=1.0, description="问答对的置信度(0-1)") + + class SemanticPruner: """语义剪枝:在预处理与分块之间过滤与场景不相关内容。 @@ -43,109 +63,353 @@ class SemanticPruner: 重要信息(时间、编号、金额、联系方式、地址等)优先保留。 """ - def __init__(self, config: Optional[PruningConfig] = None, llm_client=None): - cfg_dict = get_pruning_config() if config is None else config.model_dump() - self.config = PruningConfig.model_validate(cfg_dict) + def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5): + # 如果没有提供config,使用默认配置 + if config is None: + # 使用默认的剪枝配置 + config = PruningConfig( + pruning_switch=False, # 默认关闭剪枝,保持向后兼容 + pruning_scene="education", + pruning_threshold=0.5 + ) + + self.config = config self.llm_client = llm_client + self.language = language # 保存语言配置 + self.max_concurrent = max_concurrent # 新增:最大并发数 + # Load Jinja2 template self.template = prompt_env.get_template("extracat_Pruning.jinja2") - # 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染 - self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {} + + # 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存 + self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict() + self._cache_max_size = 1000 # 缓存大小限制 + # 运行日志:收集关键终端输出,便于写入 JSON self.run_logs: List[str] = [] - # 采用顺序处理,移除并发配置以简化与稳定执行 + + # 扩展的填充词库(包含表情符号和网络用语) + self._extended_fillers = [ + # 基础寒暄 + "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", + "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", + "拜拜", "再见", "88", "拜", "回见", + # 口头禅 + "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", + "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", + # 确认词 + "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", + # 标点和符号 + "。。。", "...", "???", "???", "!!!", "!!!", + # 表情符号(文本形式) + "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", + "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", + "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", + "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", + # 网络用语 + "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", + "emmm", "emm", "em", "mmp", "wtf", "omg", + ] def _is_important_message(self, message: ConversationMessage) -> bool: """基于启发式规则识别重要信息消息,优先保留。 - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。 - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。 + 改进版:增强了规则覆盖范围,包括: + - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午) + - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段 + - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址" + - 新增:问句识别、决策性语句、承诺性语句 """ - import re text = message.msg.strip() if not text: return False + patterns = [ - r"\b\d{4}-\d{1,2}-\d{1,2}\b", - r"\b\d{1,2}:\d{2}\b", + # 原有模式 + r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界 + r"\d{1,2}:\d{2}", # 修复:移除 \b r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM", - r"订单号|工单|申请号|编号|ID|账号|账户", - r"电话|手机号|微信|QQ|邮箱", - r"地址|地点", - r"金额|费用|价格|¥|¥|\d+元", - r"时间|日期|有效期|截止", + r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月", + r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号", + r"电话|手机号|微信|QQ|邮箱|联系方式", + r"地址|地点|位置|门牌号", + r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元", + r"时间|日期|有效期|截止|期限|到期", + # 新增模式 + r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词 + r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句 + r"承诺|保证|确保|负责|同意|答应", # 承诺性语句 + r"\d{11}", # 11位手机号 + r"\d{3,4}-\d{7,8}", # 固定电话 + r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱 ] + for p in patterns: if re.search(p, text, flags=re.IGNORECASE): return True + + # 检查是否为问句(以问号结尾或包含疑问词) + if text.endswith("?") or text.endswith("?"): + return True + return False + def _importance_score(self, message: ConversationMessage) -> int: """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - 简单启发:匹配到的类别越多、越关键分值越高。 + 改进版:更细致的评分体系(0-10分) """ - import re text = message.msg.strip() score = 0 + weights = [ - (r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3), - (r"\b\d{1,2}:\d{2}\b", 2), + # 高优先级(4-5分) + (r"订单号|工单|申请号|编号|ID|账号|账户", 5), + (r"金额|费用|价格|¥|¥|\d+元", 5), + (r"\d{11}", 4), # 手机号 + (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 + + # 中优先级(2-3分) + (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"订单号|工单|申请号|编号|ID|账号|账户", 4), - (r"电话|手机号|微信|QQ|邮箱", 3), - (r"地址|地点", 2), - (r"金额|费用|价格|¥|¥|\d+元", 4), - (r"时间|日期|有效期|截止", 2), + (r"电话|手机号|微信|QQ|联系方式", 3), + (r"地址|地点|位置", 2), + (r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词 + + # 低优先级(1分) + (r"\d{1,2}:\d{2}", 1), # 修复:移除 \b + (r"上午|下午|AM|PM", 1), ] + for p, w in weights: if re.search(p, text, flags=re.IGNORECASE): score += w - return score + + # 问句加分 + if text.endswith("?") or text.endswith("?"): + score += 2 + + # 长度加分(较长的消息通常包含更多信息) + if len(text) > 50: + score += 1 + if len(text) > 100: + score += 1 + + return min(score, 10) # 最高10分 def _is_filler_message(self, message: ConversationMessage) -> bool: """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 + 改进版:扩展了填充词库,支持表情符号和网络用语 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体; - - 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。 + - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体 + - 在扩展填充词库中 + - 纯表情符号 """ - import re t = message.msg.strip() if not t: return True - # 常见填充语 - fillers = [ - "你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢", - "拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??" - ] - if t in fillers: + + # 检查是否在扩展填充词库中 + if t in self._extended_fillers: return True + + # 检查是否为纯表情符号(方括号包裹) + if re.fullmatch(r"(\[[^\]]+\])+", t): + return True + + # 检查是否为纯emoji(Unicode表情) + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # 表情符号 + "\U0001F300-\U0001F5FF" # 符号和象形文字 + "\U0001F680-\U0001F6FF" # 交通和地图符号 + "\U0001F1E0-\U0001F1FF" # 旗帜 + "\U00002702-\U000027B0" + "\U000024C2-\U0001F251" + "]+", flags=re.UNICODE + ) + if emoji_pattern.fullmatch(t): + return True + # 长度与字符类型判断 if len(t) <= 8: # 非数字、无关键实体的短文本 if not re.search(r"[0-9]", t) and not self._is_important_message(message): # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers: + if re.fullmatch(r"[。!?,.!?…·\s]+", t): return True + return False + + async def _batch_evaluate_importance_with_llm( + self, + messages: List[ConversationMessage], + context: str = "" + ) -> Dict[int, int]: + """使用LLM批量评估消息的重要性(语义层面)。 + + Args: + messages: 消息列表 + context: 对话上下文(可选) + + Returns: + 消息索引到重要性分数(0-10)的映射 + """ + if not self.llm_client or not messages: + return {} + + # 构建批量评估的提示词 + msg_list = [] + for idx, msg in enumerate(messages): + msg_list.append(f"{idx}. {msg.msg}") + + msg_text = "\n".join(msg_list) + + prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分): +- 0-2分:无意义的寒暄、口头禅、纯表情 +- 3-5分:一般性对话,有一定信息量但不关键 +- 6-8分:包含重要信息(时间、地点、人物、事件等) +- 9-10分:关键决策、承诺、重要数据 + +对话上下文: +{context if context else "无"} + +待评估的消息: +{msg_text} + +请以JSON格式返回,格式为: +{{ + "importance_scores": {{ + "0": 分数, + "1": 分数, + ... + }} +}} +""" + + try: + messages_for_llm = [ + {"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"}, + {"role": "user", "content": prompt} + ] + + response = await self.llm_client.response_structured( + messages_for_llm, + MessageImportanceResponse + ) + + # 转换字符串键为整数键 + return {int(k): v for k, v in response.importance_scores.items()} + except Exception as e: + self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}") + return {} + + def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: + """识别对话中的问答对,用于保护问答结构的完整性。 + + Args: + messages: 消息列表 + + Returns: + 问答对列表 + """ + qa_pairs = [] + + for i in range(len(messages) - 1): + current_msg = messages[i].msg.strip() + next_msg = messages[i + 1].msg.strip() + + # 简单规则:如果当前消息是问句,下一条消息可能是答案 + is_question = ( + current_msg.endswith("?") or + current_msg.endswith("?") or + any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"]) + ) + + if is_question and next_msg: + # 检查下一条消息是否像答案(不是另一个问句) + is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) + + if is_answer: + qa_pairs.append(QAPair( + question_idx=i, + answer_idx=i + 1, + confidence=0.8 # 基于规则的置信度 + )) + + return qa_pairs + + def _get_protected_indices( + self, + messages: List[ConversationMessage], + qa_pairs: List[QAPair], + window_size: int = 2 + ) -> Set[int]: + """获取需要保护的消息索引集合(问答对+上下文窗口)。 + + Args: + messages: 消息列表 + qa_pairs: 问答对列表 + window_size: 上下文窗口大小(前后各保留几条消息) + + Returns: + 需要保护的消息索引集合 + """ + protected = set() + + for qa_pair in qa_pairs: + # 保护问答对本身 + protected.add(qa_pair.question_idx) + protected.add(qa_pair.answer_idx) + + # 保护上下文窗口 + for offset in range(-window_size, window_size + 1): + q_idx = qa_pair.question_idx + offset + a_idx = qa_pair.answer_idx + offset + + if 0 <= q_idx < len(messages): + protected.add(q_idx) + if 0 <= a_idx < len(messages): + protected.add(a_idx) + + return protected async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse: """对话级一次性抽取:从整段对话中提取重要信息并判定相关性。 - - 仅使用 LLM 结构化输出; + 改进版: + - LRU缓存管理 + - 重试机制 + - 降级策略 """ # 缓存命中则直接返回(场景+内容作为键) cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() + + # LRU缓存:如果命中,移到末尾(最近使用) if cache_key in self._dialog_extract_cache: + self._dialog_extract_cache.move_to_end(cache_key) return self._dialog_extract_cache[cache_key] - rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text) - log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene}) + # LRU缓存大小限制:超过限制时删除最旧的条目 + if len(self._dialog_extract_cache) >= self._cache_max_size: + # 删除最旧的条目(OrderedDict的第一个) + oldest_key = next(iter(self._dialog_extract_cache)) + del self._dialog_extract_cache[oldest_key] + self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目") + + rendered = self.template.render( + pruning_scene=self.config.pruning_scene, + dialog_text=dialog_text, + language=self.language + ) + log_template_rendering("extracat_Pruning.jinja2", { + "pruning_scene": self.config.pruning_scene, + "language": self.language + }) log_prompt_rendering("pruning-extract", rendered) - # 强制使用 LLM;移除正则回退 + # 强制使用 LLM if not self.llm_client: raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。") @@ -153,12 +417,32 @@ class SemanticPruner: {"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"}, {"role": "user", "content": rendered}, ] - try: - ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) - self._dialog_extract_cache[cache_key] = ex - return ex - except Exception as e: - raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e + + # 重试机制 + max_retries = 3 + for attempt in range(max_retries): + try: + ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) + self._dialog_extract_cache[cache_key] = ex + return ex + except Exception as e: + if attempt < max_retries - 1: + self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}") + await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避 + continue + else: + # 降级策略:标记为相关,避免误删 + self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)") + fallback_response = DialogExtractionResponse( + is_related=True, + times=[], + ids=[], + amounts=[], + contacts=[], + addresses=[], + keywords=[] + ) + return fallback_response def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: """判断消息是否包含任意抽取到的重要片段。""" @@ -248,12 +532,15 @@ class SemanticPruner: async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]: """数据集层面:全局消息级剪枝,保留所有对话。 - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。 - - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。 - - 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。 - - 保证每段对话至少保留1条消息,不会删除整段对话。 + 改进版: + - 并发处理对话级相关性判断 + - 问答对识别和保护 + - 优化删除策略,保持上下文连贯性 + - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动 + - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 + - 保证每段对话至少保留1条消息,不会删除整段对话 """ - # 如果剪枝功能关闭,直接返回原始数据集。 + # 如果剪枝功能关闭,直接返回原始数据集 if not self.config.pruning_switch: return dialogs @@ -264,29 +551,36 @@ class SemanticPruner: proportion = 0.9 if proportion < 0.0: proportion = 0.0 - evaluated_dialogs = [] # list of dicts: {dialog, is_related} self._log( f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" ) - # 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存) - evaluated_dialogs = [] - for idx, dd in enumerate(dialogs): - try: - ex = await self._extract_dialog_important(dd.content) - evaluated_dialogs.append({ - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - }) - except Exception: - evaluated_dialogs.append({ - "dialog": dd, - "is_related": True, - "index": idx, - "extraction": None - }) + + # 并发处理对话级相关性分类 + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def classify_dialog(idx: int, dd: DialogData): + async with semaphore: + try: + ex = await self._extract_dialog_important(dd.content) + return { + "dialog": dd, + "is_related": bool(ex.is_related), + "index": idx, + "extraction": ex + } + except Exception as e: + self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}") + return { + "dialog": dd, + "is_related": True, # 失败时标记为相关,避免误删 + "index": idx, + "extraction": None + } + + # 并发执行所有对话的分类 + tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)] + evaluated_dialogs = await asyncio.gather(*tasks) # 统计相关 / 不相关对话 not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] @@ -300,7 +594,6 @@ class SemanticPruner: inds = [i["index"] + 1 for i in items] if len(inds) <= cap: return inds - # 超过上限时只打印前cap个,并标注总数 return inds[:cap] + ["...", f"共{len(inds)}个"] rel_inds = _fmt_indices(related_dialogs) @@ -309,59 +602,83 @@ class SemanticPruner: result: List[DialogData] = [] if not_related_dialogs: - # 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM) + # 为每个不相关对话进行分析 per_dialog_info = {} total_unrelated = 0 - total_capacity = 0 + for d in not_related_dialogs: dd = d["dialog"] extraction = d.get("extraction") if extraction is None: extraction = await self._extract_dialog_important(dd.content) + # 合并所有重要标记 tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords msgs = dd.context.msgs - # 分类消息 - imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)] - unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs] + + # 识别问答对 + qa_pairs = self._identify_qa_pairs(msgs) + protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1) + + # 分类消息(考虑问答对保护) + imp_unrel_msgs = [] + unimp_unrel_msgs = [] + + for idx, m in enumerate(msgs): + # 问答对中的消息自动标记为重要 + if idx in protected_indices: + imp_unrel_msgs.append((idx, m)) + elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m): + imp_unrel_msgs.append((idx, m)) + elif not self._is_filler_message(m): + unimp_unrel_msgs.append((idx, m)) + # 填充消息不加入任何列表,优先删除 + # 重要消息按重要性排序 - imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))] + imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1])) + imp_sorted_ids = [id(m) for _, m in imp_sorted] + info = { "dialog": dd, "total_msgs": len(msgs), "unrelated_count": len(msgs), "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for m in unimp_unrel_msgs], + "unimp_ids": [id(m) for _, m in unimp_unrel_msgs], + "protected_indices": protected_indices, + "qa_pairs_count": len(qa_pairs), } per_dialog_info[d["index"]] = info total_unrelated += info["unrelated_count"] - # 全局删除配额:比例作用于全部不相关消息(重要+不重要) + + # 全局删除配额计算 global_delete = int(total_unrelated * proportion) if proportion > 0 and total_unrelated > 0 and global_delete == 0: global_delete = 1 - # 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息 + + # 每段的最大可删容量 capacities = [] for d in not_related_dialogs: idx = d["index"] info = per_dialog_info[idx] - # 统计重要数量 imp_count = len(info["imp_ids_sorted"]) unimp_count = len(info["unimp_ids"]) imp_cap = int(imp_count * proportion) cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) capacities.append(cap) + total_capacity = sum(capacities) if global_delete > total_capacity: - print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。") + self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。") global_delete = total_capacity - # 配额分配:按不相关消息占比分配到各对话,但不超过各自容量 + # 配额分配 alloc = [] for i, d in enumerate(not_related_dialogs): idx = d["index"] info = per_dialog_info[idx] share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 alloc.append(min(share, capacities[i])) + allocated = sum(alloc) rem = global_delete - allocated turn = 0 @@ -378,34 +695,40 @@ class SemanticPruner: break turn += 1 - # 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先) + # 应用删除 total_deleted_confirm = 0 for d in evaluated_dialogs: dd = d["dialog"] msgs = dd.context.msgs original = len(msgs) + if d["is_related"]: result.append(dd) continue + idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) if idx_in_unrel is None: result.append(dd) continue + quota = alloc[idx_in_unrel] info = per_dialog_info[d["index"]] - # 计算本对话重要最多可删数量 + + # 计算删除ID imp_count = len(info["imp_ids_sorted"]) imp_del_cap = int(imp_count * proportion) - # 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条) + unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) del_unimp = min(quota, len(unimp_delete_ids)) rem_quota = quota - del_unimp - # 再从重要里选低分优先的删除ID(不超过 imp_del_cap) + imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) + deleted_here = 0 actual_unimp_deleted = 0 actual_imp_deleted = 0 kept = [] + for m in msgs: mid = id(m) if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: @@ -417,26 +740,30 @@ class SemanticPruner: deleted_here += 1 continue kept.append(m) + if not kept and msgs: kept = [msgs[0]] + dd.context.msgs = kept total_deleted_confirm += deleted_here + + qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else "" self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}" + f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}" ) result.append(dd) - self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。") + + self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。") else: - # 全部相关:不执行剪枝 result = [d["dialog"] for d in evaluated_dialogs] + self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") - # 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成) + # 保存日志 try: from app.core.config import settings settings.ensure_memory_output_dir() log_output_path = settings.get_memory_output_path("pruned_terminal.json") - # 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存 sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs] payload = self._parse_logs_to_structured(sanitized_logs) with open(log_output_path, "w", encoding="utf-8") as f: @@ -448,6 +775,7 @@ class SemanticPruner: if not result: print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs + return result def _log(self, msg: str) -> None: From f7aed9dd9807f7746bece974c7f964cfc2bbe540 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 16:45:34 +0800 Subject: [PATCH 008/164] [fix]Correct the flaws existing in the semantic segmentation method --- .../knowledge_extraction/chunk_extraction.py | 205 ++++++++++++++---- 1 file changed, 160 insertions(+), 45 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index 40e98507..bbbf1c51 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -1,5 +1,7 @@ import os -from typing import Optional +from typing import Optional, List, Any +from enum import Enum +from pathlib import Path from app.core.logging_config import get_memory_logger from app.core.memory.models.message_models import DialogData, Chunk @@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config logger = get_memory_logger(__name__) +class ChunkerStrategy(Enum): + """Supported chunking strategies.""" + RECURSIVE = "RecursiveChunker" + SEMANTIC = "SemanticChunker" + LATE = "LateChunker" + NEURAL = "NeuralChunker" + LLM = "LLMChunker" + + @classmethod + def get_valid_strategies(cls) -> List[str]: + """Get list of valid strategy names.""" + return [strategy.value for strategy in cls] + + class DialogueChunker: """A class that processes dialogues and fills them with chunks based on a specified strategy. @@ -17,23 +33,51 @@ class DialogueChunker: of different chunking strategies to dialogue data. """ - def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None): + def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None): """Initialize the DialogueChunker with a specific chunking strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker + llm_client: LLM client instance (required for LLMChunker strategy) + + Raises: + ValueError: If chunker_strategy is invalid or required parameters are missing """ - self.chunker_strategy = chunker_strategy - chunker_config_dict = get_chunker_config(chunker_strategy) - self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + # Validate strategy + valid_strategies = ChunkerStrategy.get_valid_strategies() + if chunker_strategy not in valid_strategies: + raise ValueError( + f"Invalid chunker_strategy: '{chunker_strategy}'. " + f"Must be one of {valid_strategies}" + ) - if self.chunker_config.chunker_strategy == "LLMChunker": - self.chunker_client = ChunkerClient(self.chunker_config, llm_client) - else: - self.chunker_client = ChunkerClient(self.chunker_config) + self.chunker_strategy = chunker_strategy + logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}") + + try: + # Load and validate configuration + chunker_config_dict = get_chunker_config(chunker_strategy) + if not chunker_config_dict: + raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}") + + self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + + # Initialize chunker client + if self.chunker_config.chunker_strategy == "LLMChunker": + if not llm_client: + raise ValueError("llm_client is required for LLMChunker strategy") + self.chunker_client = ChunkerClient(self.chunker_config, llm_client) + else: + self.chunker_client = ChunkerClient(self.chunker_config) + + logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}") + + except Exception as e: + logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True) + raise - async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]: + async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]: """Process a dialogue by generating chunks and adding them to the DialogData object. Args: @@ -43,54 +87,125 @@ class DialogueChunker: A list of Chunk objects Raises: - ValueError: If chunking fails or returns empty chunks + ValueError: If dialogue is invalid or chunking fails + Exception: If chunking process encounters an error """ - result_dialogue = await self.chunker_client.generate_chunks(dialogue) - chunks = result_dialogue.chunks - - if not chunks or len(chunks) == 0: + # Validate input + if not dialogue: + raise ValueError("dialogue cannot be None") + + if not dialogue.context or not dialogue.context.msgs: raise ValueError( - f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " - f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " - f"Strategy: {self.chunker_config.chunker_strategy}" + f"Dialogue {dialogue.ref_id} has no messages to chunk. " + f"Context: {dialogue.context is not None}, " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}" ) + + logger.info( + f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages " + f"using strategy: {self.chunker_strategy}" + ) + + try: + # Generate chunks + result_dialogue = await self.chunker_client.generate_chunks(dialogue) + chunks = result_dialogue.chunks - return chunks + # Validate results + if not chunks or len(chunks) == 0: + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs)}, " + f"Content length: {len(dialogue.content) if dialogue.content else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) - def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str: + logger.info( + f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. " + f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}" + ) + + return chunks + + except ValueError: + # Re-raise validation errors + raise + except Exception as e: + logger.error( + f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}", + exc_info=True + ) + raise + + def save_chunking_results( + self, + chunks: List[Chunk], + dialogue: DialogData, + output_path: Optional[str] = None, + preview_length: int = 100 + ) -> str: """Save the chunking results to a file and return the output path. Args: - dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output + chunks: List of Chunk objects to save + dialogue: The DialogData object that was processed + output_path: Optional path to save the output (defaults to current directory) + preview_length: Maximum length of content preview (default: 100) Returns: The path where the output was saved + + Raises: + ValueError: If chunks or dialogue is invalid + IOError: If file writing fails """ - if not output_path: - output_path = os.path.join( - os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt" - ) - - output_lines = [ - f"=== Chunking Results ({self.chunker_strategy}) ===", - f"Dialogue ID: {dialogue.ref_id}", - f"Original conversation has {len(dialogue.context.msgs)} messages", - f"Total characters: {len(dialogue.content)}", - f"Generated {len(dialogue.chunks)} chunks:" - ] + # Validate input + if not chunks: + raise ValueError("chunks list cannot be empty") + if not dialogue: + raise ValueError("dialogue cannot be None") - for i, chunk in enumerate(dialogue.chunks): - output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") - output_lines.append(f" Content preview: {chunk.content}...") - if chunk.metadata: - output_lines.append(f" Metadata: {chunk.metadata}") + # Generate default output path if not provided + if not output_path: + output_dir = Path(__file__).parent.parent.parent + output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt") + + logger.info(f"Saving chunking results to: {output_path}") + + try: + # Prepare output content + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages", + f"Total characters: {len(dialogue.content) if dialogue.content else 0}", + f"Generated {len(chunks)} chunks:", + "" + ] + + for i, chunk in enumerate(chunks, 1): + content_preview = chunk.content[:preview_length] if chunk.content else "" + if len(chunk.content) > preview_length: + content_preview += "..." + + output_lines.append(f" Chunk {i}: {len(chunk.content)} characters") + output_lines.append(f" Content preview: {content_preview}") + if chunk.metadata: + output_lines.append(f" Metadata: {chunk.metadata}") + output_lines.append("") - with open(output_path, "w", encoding="utf-8") as f: - f.write("\n".join(output_lines)) + # Write to file + with open(output_path, "w", encoding="utf-8") as f: + f.write("\n".join(output_lines)) - logger.info(f"Chunking results saved to: {output_path}") - return output_path + logger.info(f"Successfully saved chunking results to: {output_path}") + return output_path + + except IOError as e: + logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True) + raise + except Exception as e: + logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True) + raise From 9916cf3265e11fe27cd74e725a01a557657f75c0 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Sat, 28 Feb 2026 10:29:14 +0800 Subject: [PATCH 009/164] feat(workflow): add Dify workflow import adapter and related APIs --- api/app/aioRedis.py | 44 +- api/app/controllers/app_controller.py | 72 +- api/app/core/config.py | 22 +- api/app/core/workflow/adapters/__init__.py | 8 + .../core/workflow/adapters/base_adapter.py | 88 +++ .../core/workflow/adapters/base_converter.py | 75 ++ .../core/workflow/adapters/dify/__init__.py | 4 + .../core/workflow/adapters/dify/converter.py | 659 ++++++++++++++++++ .../workflow/adapters/dify/dify_adapter.py | 239 +++++++ api/app/core/workflow/adapters/errors.py | 75 ++ .../workflow/adapters/memory_bear/__init__.py | 4 + .../memory_bear/memory_bear_adapter.py | 76 ++ api/app/core/workflow/adapters/registry.py | 34 + .../engine/stream_output_coordinator.py | 2 +- api/app/core/workflow/nodes/assigner/node.py | 2 + api/app/core/workflow/nodes/end/config.py | 22 +- api/app/core/workflow/nodes/enums.py | 1 + .../core/workflow/nodes/http_request/node.py | 2 +- .../core/workflow/nodes/knowledge/config.py | 2 +- api/app/core/workflow/nodes/start/config.py | 73 +- api/app/schemas/app_schema.py | 4 + api/app/schemas/workflow_schema.py | 47 +- api/app/services/app_service.py | 61 +- api/app/services/workflow_import_service.py | 102 +++ api/app/services/workflow_service.py | 31 + 25 files changed, 1625 insertions(+), 124 deletions(-) create mode 100644 api/app/core/workflow/adapters/__init__.py create mode 100644 api/app/core/workflow/adapters/base_adapter.py create mode 100644 api/app/core/workflow/adapters/base_converter.py create mode 100644 api/app/core/workflow/adapters/dify/__init__.py create mode 100644 api/app/core/workflow/adapters/dify/converter.py create mode 100644 api/app/core/workflow/adapters/dify/dify_adapter.py create mode 100644 api/app/core/workflow/adapters/errors.py create mode 100644 api/app/core/workflow/adapters/memory_bear/__init__.py create mode 100644 api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py create mode 100644 api/app/core/workflow/adapters/registry.py create mode 100644 api/app/services/workflow_import_service.py diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index c729a3dc..f758dd15 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -10,7 +10,6 @@ from app.core.config import settings # 设置日志记录器 logger = logging.getLogger(__name__) - # 创建连接池 pool = ConnectionPool.from_url( f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}", @@ -21,6 +20,7 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) + async def get_redis_connection(): """获取Redis连接""" try: @@ -29,7 +29,8 @@ async def get_redis_connection(): logger.error(f"Redis连接失败: {str(e)}") return None -async def aio_redis_set(key: str, val: str|dict, expire: int = None): + +async def aio_redis_set(key: str, val: str | dict, expire: int = None): """设置Redis键值 Args: @@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): try: if isinstance(val, dict): val = json.dumps(val, ensure_ascii=False) - + if expire is not None: # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) @@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): except Exception as e: logger.error(f"Redis set错误: {str(e)}") + async def aio_redis_get(key: str): """获取Redis键值""" try: @@ -58,6 +60,7 @@ async def aio_redis_get(key: str): logger.error(f"Redis get错误: {str(e)}") return None + async def aio_redis_delete(key: str): """删除Redis键""" try: @@ -66,6 +69,7 @@ async def aio_redis_delete(key: str): logger.error(f"Redis delete错误: {str(e)}") return None + async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: """发布消息到Redis频道""" try: @@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: logger.error(f"Redis发布错误: {str(e)}") return False + class RedisSubscriber: """Redis订阅器""" - + def __init__(self, channel: str): self.channel = channel self.conn = None @@ -88,25 +93,25 @@ class RedisSubscriber: self.is_closed = False self._queue = asyncio.Queue() self._task = None - + async def start(self): """开始订阅""" if self.is_closed or self._task: return - + self._task = asyncio.create_task(self._receive_messages()) logger.info(f"开始订阅: {self.channel}") - + async def _receive_messages(self): """接收消息""" try: self.conn = await get_redis_connection() if not self.conn: return - + self.pubsub = self.conn.pubsub() await self.pubsub.subscribe(self.channel) - + while not self.is_closed: try: message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01) @@ -127,7 +132,7 @@ class RedisSubscriber: finally: await self._queue.put(None) await self._cleanup() - + async def _cleanup(self): """清理资源""" if self.pubsub: @@ -141,7 +146,7 @@ class RedisSubscriber: await self.conn.close() except Exception: pass - + async def get_message(self) -> Optional[Dict[str, Any]]: """获取消息""" if self.is_closed: @@ -153,7 +158,7 @@ class RedisSubscriber: except Exception as e: logger.error(f"获取消息错误: {str(e)}") return None - + async def close(self): """关闭订阅器""" if self.is_closed: @@ -163,32 +168,33 @@ class RedisSubscriber: self._task.cancel() await self._cleanup() + class RedisPubSubManager: """Redis发布订阅管理器""" - + def __init__(self): self.subscribers = {} - + async def publish(self, channel: str, message: Dict[str, Any]) -> bool: return await aio_redis_publish(channel, message) - + def get_subscriber(self, channel: str) -> RedisSubscriber: if channel in self.subscribers: subscriber = self.subscribers[channel] if not subscriber.is_closed: return subscriber - + subscriber = RedisSubscriber(channel) self.subscribers[channel] = subscriber return subscriber - + def cancel_subscription(self, channel: str) -> bool: if channel in self.subscribers: asyncio.create_task(self.subscribers[channel].close()) del self.subscribers[channel] return True return False - + def cancel_all_subscriptions(self) -> int: count = len(self.subscribers) for subscriber in self.subscribers.values(): @@ -196,6 +202,6 @@ class RedisPubSubManager: self.subscribers.clear() return count + # 全局实例 pubsub_manager = RedisPubSubManager() - diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index f1508114..e2849ad6 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -1,7 +1,8 @@ import uuid from typing import Optional, Annotated -from fastapi import APIRouter, Depends, Path +import yaml +from fastapi import APIRouter, Depends, Path, Form, UploadFile, File from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository from app.schemas import app_schema from app.schemas.response_schema import PageData, PageMeta from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema -from app.schemas.workflow_schema import WorkflowConfigUpdate +from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave from app.services import app_service, workspace_service from app.services.agent_config_helper import enrich_agent_config from app.services.app_service import AppService -from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.app_statistics_service import AppStatisticsService +from app.services.workflow_import_service import WorkflowImportService +from app.services.workflow_service import WorkflowService, get_workflow_service router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -65,7 +67,7 @@ def list_apps( # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: - app_ids = [id.strip() for id in ids.split(',') if id.strip()] + app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items) @@ -879,6 +881,60 @@ async def update_workflow_config( return success(data=WorkflowConfigSchema.model_validate(cfg)) +@router.get("/{app_id}/workflow/export") +@cur_workspace_access_guard() +async def export_workflow_config( + app_id: uuid.UUID, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] +): + """导出工作流配置为YAML文件""" + workflow_service = WorkflowService(db) + + return success(data={ + "content": workflow_service.export_workflow_dsl(app_id=app_id), + }) + + +@router.post("/workflow/import") +@cur_workspace_access_guard() +async def import_workflow_config( + file: UploadFile = File(...), + platform: str = Form(...), + app_id: str = Form(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) + +): + """从YAML内容导入工作流配置""" + if not file.filename.lower().endswith((".yaml", ".yml")): + return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST) + + raw_text = (await file.read()).decode("utf-8") + import_service = WorkflowImportService(db) + config = yaml.safe_load(raw_text) + result = await import_service.upload_config(platform, config) + return success(data=result) + + +@router.post("/workflow/import/save") +@cur_workspace_access_guard() +async def save_workflow_import( + data: WorkflowImportSave, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + import_service = WorkflowImportService(db) + app = await import_service.save_workflow( + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + temp_id=data.temp_id, + name=data.name, + description=data.description, + ) + return success(data=app_schema.App.model_validate(app)) + + @router.get("/{app_id}/statistics", summary="应用统计数据") @cur_workspace_access_guard() def get_app_statistics( @@ -889,12 +945,14 @@ def get_app_statistics( current_user=Depends(get_current_user), ): """获取应用统计数据 - + Args: app_id: 应用ID start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) - + db: 数据库连接 + current_user: 当前用户 + Returns: - daily_conversations: 每日会话数统计 - total_conversations: 总会话数 @@ -931,6 +989,8 @@ def get_workspace_api_statistics( Args: start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) + db: 数据库连接 + current_user: 当前用户 Returns: 每日统计数据列表,每项包含: diff --git a/api/app/core/config.py b/api/app/core/config.py index 3a0c97b4..0962b545 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -16,18 +16,18 @@ class Settings: # cloud: SaaS 云服务版(全功能,按量计费) # enterprise: 企业私有化版(License 控制) DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community") - + # License 配置(企业版) LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json") LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com") - + # 计费服务配置(SaaS 版) BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "") - + # 基础 URL(用于 SSO 回调等) BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000") FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000") - + ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") @@ -57,7 +57,6 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") @@ -91,7 +90,7 @@ class Settings: # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" - + # SSO 免登配置 SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300")) SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}") @@ -130,7 +129,7 @@ class Settings: # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") - FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") + FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") # ======================================================================== # Internal Configuration (not in .env, used by application code) @@ -225,6 +224,7 @@ class Settings: LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" # workflow config + WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800)) WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) # ======================================================================== @@ -232,20 +232,20 @@ class Settings: # ======================================================================== # 通用本体文件路径列表(逗号分隔) GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl") - + # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" - + # Prompt 中最大类型数量 MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50")) - + # 核心通用类型列表(逗号分隔) CORE_GENERAL_TYPES: str = os.getenv( "CORE_GENERAL_TYPES", "Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building," "Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject" ) - + # 实验模式开关(允许通过 API 动态切换本体配置) ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true" diff --git a/api/app/core/workflow/adapters/__init__.py b/api/app/core/workflow/adapters/__init__.py new file mode 100644 index 00000000..141aa4ab --- /dev/null +++ b/api/app/core/workflow/adapters/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:54 +from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter +from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter + +__all__ = ["DifyAdapter", "MemoryBearAdapter"] diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py new file mode 100644 index 00000000..601c8ff2 --- /dev/null +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -0,0 +1,88 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:58 +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from app.core.workflow.adapters.errors import ExceptionDefineition +from app.schemas.workflow_schema import ( + EdgeDefinition, + NodeDefinition, + VariableDefinition, + ExecutionConfig, + TriggerConfig +) + + +class PlatformType(StrEnum): + MEMORY_BEAR = "memory_bear" + DIFY = "dify" + COZE = "coze" + + +class PlatformMetadata(BaseModel): + platform_name: str + version: str + support_node_types: list[str] + + +class WorkflowParserResult(BaseModel): + success: bool + platform: PlatformMetadata + execution_config: ExecutionConfig + origin_config: dict[str, Any] + trigger: TriggerConfig | None + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class WorkflowImportResult(BaseModel): + success: bool + temp_id: str | None = Field(..., description="cache id") + workflow_id: str | None = Field(..., description="workflow id") + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class BasePlatformAdapter(ABC): + def __init__(self, config: dict[str, Any]): + self.config = config + self.nodes: list[NodeDefinition] = [] + self.edges: list[EdgeDefinition] = [] + self.conv_variables: list[VariableDefinition] = [] + + self.errors = [] + self.warnings = [] + + self.branch_node_cache = defaultdict(list) + self.error_branch_node_cache = [] + + @abstractmethod + def get_metadata(self) -> PlatformMetadata: + """get platform metadata""" + pass + + @abstractmethod + def validate_config(self) -> bool: + """platform configuration validate""" + pass + + @abstractmethod + def parse_workflow(self) -> WorkflowParserResult: + """parse platform configuration to local config""" + pass + + @abstractmethod + def map_node_type(self, platform_node_type: str) -> str: + pass diff --git a/api/app/core/workflow/adapters/base_converter.py b/api/app/core/workflow/adapters/base_converter.py new file mode 100644 index 00000000..eebde971 --- /dev/null +++ b/api/app/core/workflow/adapters/base_converter.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 14:32 +from abc import ABC, abstractmethod + +from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType + + +class BaseConverter(ABC): + @staticmethod + def _convert_string(var): + try: + return str(var) + except: + return DEFAULT_VALUE(VariableType.STRING) + + @staticmethod + def _convert_boolean(var): + try: + return bool(var) + except: + return DEFAULT_VALUE(VariableType.BOOLEAN) + + @staticmethod + def _convert_number(var): + try: + return float(var) + except: + return DEFAULT_VALUE(VariableType.NUMBER) + + @staticmethod + def _convert_object(var): + try: + return dict(var) + except: + return DEFAULT_VALUE(VariableType.OBJECT) + + @staticmethod + @abstractmethod + def _convert_file(var): + pass + + @staticmethod + def _convert_array_string(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_STRING) + + @staticmethod + def _convert_array_number(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_NUMBER) + + @staticmethod + def _convert_array_boolean(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN) + + @staticmethod + def _convert_array_object(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_OBJECT) + + @staticmethod + @abstractmethod + def _convert_array_file(var): + pass diff --git a/api/app/core/workflow/adapters/dify/__init__.py b/api/app/core/workflow/adapters/dify/__init__.py new file mode 100644 index 00000000..7774dcaa --- /dev/null +++ b/api/app/core/workflow/adapters/dify/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:20 diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py new file mode 100644 index 00000000..0e92b2c7 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -0,0 +1,659 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:21 +import base64 +import re +from typing import Any +from urllib.parse import quote + +from app.core.workflow.adapters.base_converter import BaseConverter +from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ + ExceptionType +from app.core.workflow.nodes.assigner import AssignerNodeConfig +from app.core.workflow.nodes.assigner.config import AssignmentItem +from app.core.workflow.nodes.base_config import VariableDefinition +from app.core.workflow.nodes.code import CodeNodeConfig +from app.core.workflow.nodes.code.config import InputVariable, OutputVariable +from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ + CycleVariable +from app.core.workflow.nodes.end import EndNodeConfig +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ + HttpContentType, HttpErrorHandle +from app.core.workflow.nodes.http_request import HttpRequestNodeConfig +from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ + HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig +from app.core.workflow.nodes.if_else import IfElseNodeConfig +from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig +from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig +from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig +from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig +from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig +from app.core.workflow.nodes.question_classifier.config import ClassifierConfig +from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE + + +class DifyConverter(BaseConverter): + errors: list + warnings: list + branch_node_cache: dict + error_branch_node_cache: list + + def __init__(self): + self.CONFIG_CONVERT_MAP = { + "start": self.convert_start_node_config, + "llm": self.convert_llm_node_config, + "answer": self.convert_end_node_config, + "if-else": self.convert_if_else_node_config, + "loop": self.convert_loop_node_config, + "iteration": self.convert_iteration_node_config, + "assigner": self.convert_assigner_node_config, + "code": self.convert_code_node_config, + "http-request": self.convert_http_node_config, + "template-transform": self.convert_jinja_render_node_config, + "knowledge-retrieval": self.convert_knowledge_node_config, + "parameter-extractor": self.convert_parameter_extractor_node_config, + "question-classifier": self.convert_question_classifier_node_config, + "variable-aggregator": self.convert_variable_aggregator, + "loop-start": lambda x: {}, + "iteration-start": lambda x: {}, + "loop-end": lambda x: {}, + } + + def get_node_convert(self, node_type): + func = self.CONFIG_CONVERT_MAP.get(node_type, None) + return func + + @staticmethod + def is_variable(expression) -> bool: + return bool(re.match(r"\{\{#(.*?)#}}", expression)) + + @staticmethod + def process_var_selector(var_selector): + if not var_selector: + return "" + selector = var_selector.split('.') + if len(selector) != 2: + raise Exception(f"invalid variable selector: {var_selector}") + if selector[0] == "conversation": + selector[0] = "conv" + var_selector = ".".join(selector) + mapping = { + "sys.query": "sys.message" + } + + var_selector = mapping.get(var_selector, var_selector) + return var_selector + + def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + if not self.process_var_selector(".".join(variable_selector)): + return None + return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" + + def trans_variable_format(self, content): + pattern = re.compile(r"\{\{#(.*?)#}}") + + def replacer(match: re.Match) -> str: + raw_name = match.group(1) + new_name = self.process_var_selector(raw_name) + return f"{{{{{new_name}}}}}" + + return pattern.sub(replacer, content) + + @staticmethod + def _convert_file(var): + pass + + @staticmethod + def _convert_array_file(var): + pass + + @staticmethod + def variable_type_map(source_type) -> VariableType | None: + type_map = { + "file": VariableType.FILE, + "paragraph": VariableType.STRING, + "text-input": VariableType.STRING, + "number": VariableType.NUMBER, + "checkbox": VariableType.BOOLEAN, + "file-list": VariableType.ARRAY_FILE, + "select": VariableType.STRING, + } + var_type = type_map.get(source_type, source_type) + return var_type + + def convert_variable_type(self, target_type: VariableType, origin_value: Any): + if not origin_value: + return DEFAULT_VALUE(target_type) + try: + match target_type: + case VariableType.STRING: + return self._convert_string(origin_value) + case VariableType.NUMBER: + return self._convert_number(origin_value) + case VariableType.BOOLEAN: + return self._convert_boolean(origin_value) + case VariableType.FILE: + return self._convert_file(origin_value) + case VariableType.ARRAY_FILE: + return self._convert_array_file(origin_value) + case _: + return origin_value + except: + raise Exception(f"convert variable failed: {target_type}") + + @staticmethod + def convert_compare_operator(operator): + operator_map = { + "is": ComparisonOperator.EQ, + "is not": ComparisonOperator.NE, + "=": ComparisonOperator.EQ, + "≠": ComparisonOperator.NE, + ">": ComparisonOperator.GT, + "<": ComparisonOperator.LT, + "≥": ComparisonOperator.GE, + "≤": ComparisonOperator.LE, + "not empty": ComparisonOperator.NOT_EMPTY, + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_assignment_operator(operator): + operator_map = { + "+=": AssignmentOperator.ADD, + "-=": AssignmentOperator.SUBTRACT, + "*=": AssignmentOperator.MULTIPLY, + "/=": AssignmentOperator.DIVIDE, + "over-write": AssignmentOperator.COVER, + "remove-last": AssignmentOperator.REMOVE_LAST, + "remove-first": AssignmentOperator.REMOVE_FIRST, + + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_http_auth_type(auth_type): + auth_type_map = { + "no-auth": HttpAuthType.NONE, + "bearer": HttpAuthType.BEARER, + "basic": HttpAuthType.BASIC, + "custom": HttpAuthType.CUSTOM, + } + return auth_type_map.get(auth_type, auth_type) + + @staticmethod + def convert_http_content_type(content_type): + content_type_map = { + "none": HttpContentType.NONE, + "form-data": HttpContentType.FROM_DATA, + "x-www-form-urlencoded": HttpContentType.WWW_FORM, + "json": HttpContentType.JSON, + "raw-text": HttpContentType.RAW, + "binary": HttpContentType.BINARY, + } + return content_type_map.get(content_type, content_type) + + @staticmethod + def convert_http_error_handle_type(handle_type): + handle_type_map = { + "none": HttpErrorHandle.NONE, + "fail-branch": HttpErrorHandle.BRANCH, + "default-value": HttpErrorHandle.DEFAULT, + } + return handle_type_map.get(handle_type, handle_type) + + def convert_start_node_config(self, node: dict) -> dict: + node_data = node["data"] + start_vars = [] + for var in node_data["variables"]: + var_type = self.variable_type_map(var["type"]) + if not var_type: + self.errors.append( + UnsupportVariableType( + scope=node["id"], + name=var["variable"], + var_type=var["type"], + node_id=node["id"], + node_name=node_data["title"] + ) + ) + continue + + if var_type in ["file", "array[file]"]: + self.errors.append( + ExceptionDefineition( + type=ExceptionType.VARIABLE, + node_id=node["id"], + node_name=node_data["title"], + name=var["variable"], + detail=f"Unsupport Variable type for start node: {var_type}" + ) + ) + continue + + var_def = VariableDefinition( + name=var["variable"], + type=var_type, + required=var["required"], + default=self.convert_variable_type( + var_type, var["default"] + ), + description=var["label"], + max_length=var.get("max_length"), + ) + start_vars.append(var_def) + return StartNodeConfig( + variables=start_vars + ).model_dump() + + def convert_question_classifier_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + categories = [] + for category in node_data["classes"]: + self.branch_node_cache[node["id"]].append(category["id"]) + categories.append( + ClassifierConfig( + class_name=category["name"], + ) + ) + + return QuestionClassifierNodeConfig.model_construct( + input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]), + user_supplement_prompt=self.trans_variable_format(node_data["instructions"]), + categories=categories, + ).model_dump() + + def convert_llm_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + memory = MemoryWindowSetting( + enable=bool(node_data.get("memory")), + enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), + window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20)) + ) + messages = [] + for message in node_data["prompt_template"]: + messages.append( + MessageConfig( + role=message["role"], + content=self.trans_variable_format(message["text"]) + ) + ) + if memory.enable: + messages.append( + MessageConfig( + role="user", + content=self.trans_variable_format(node_data["memory"]["query_prompt_template"]) + ) + ) + vision = node_data["vision"]["enabled"] + vision_input = self._process_list_variable_litearl( + node_data["vision"]["configs"]["variable_selector"] + ) if vision else None + return LLMNodeConfig.model_construct( + model_id=None, + context=context, + memory=memory, + vision=vision, + vision_input=vision_input, + messages=messages + ).model_dump() + + def convert_end_node_config(self, node: dict) -> dict: + node_data = node["data"] + return EndNodeConfig( + output=self.trans_variable_format(node_data["answer"]), + ).model_dump() + + def convert_if_else_node_config(self, node: dict) -> dict: + node_data = node["data"] + cases = [] + for case in node_data["cases"]: + case_id = case["id"] + logical_operator = case["logical_operator"] + conditions = [] + for condition in case["conditions"]: + right_value = condition["value"] + condition_detail = ConditionDetail( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}", + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + conditions.append(condition_detail) + cases.append( + ConditionBranchConfig( + logical_operator=logical_operator, + expressions=conditions + ) + ) + self.branch_node_cache[node["id"]].append(case_id) + return IfElseNodeConfig( + cases=cases + ).model_dump() + + def convert_loop_node_config(self, node: dict) -> dict: + node_data = node["data"] + logical_operator = node_data["logical_operator"] + conditions = [] + for condition in node_data["break_conditions"]: + right_value = condition["value"] + conditions.append( + LoopConditionDetail( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left=self._process_list_variable_litearl(condition["variable_selector"]), + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + ) + condition_config = ConditionsConfig( + logical_operator=logical_operator, + expressions=conditions + ) + loop_variables = [] + for variable in node_data["loop_variables"]: + right_input_type = variable["value_type"] + right_value_type = self.variable_type_map(variable["var_type"]) + if right_input_type == ValueInputType.VARIABLE: + right_value = self._process_list_variable_litearl(variable["value"]) + else: + right_value = self.convert_variable_type(right_value_type, variable["value"]) + loop_variables.append( + CycleVariable( + name=variable["label"], + type=right_value_type, + value=right_value, + input_type=right_input_type + ) + ) + return LoopNodeConfig( + condition=condition_config, + cycle_vars=loop_variables, + max_loop=node_data["loop_count"] + ).model_dump() + + def convert_iteration_node_config(self, node: dict) -> dict: + node_data = node["data"] + return IterationNodeConfig( + input=self._process_list_variable_litearl(node_data["iterator_selector"]), + parallel=node_data["is_parallel"], + parallel_count=node_data["parallel_nums"], + output=self._process_list_variable_litearl(node_data["output_selector"]), + output_type=self.variable_type_map(node_data["output_type"]), + flatten=node_data["flatten_output"], + ).model_dump() + + def convert_assigner_node_config(self, node: dict) -> dict: + node_data = node["data"] + assignments = [] + for assignment in node_data["items"]: + if assignment.get("operation") is None or assignment.get("value") is None: + continue + assignments.append( + AssignmentItem( + variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), + value=self._process_list_variable_litearl( + assignment["value"] + ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], + operation=self.convert_assignment_operator(assignment["operation"]) + ) + ) + return AssignerNodeConfig( + assignments=assignments + ).model_dump() + + def convert_code_node_config(self, node: dict) -> dict: + node_data = node["data"] + input_variables = [] + for input_variable in node_data["variables"]: + input_variables.append( + InputVariable( + name=input_variable["variable"], + variable=self._process_list_variable_litearl(input_variable["value_selector"]), + ) + ) + + output_variables = [] + for output_variable in node_data["outputs"]: + output_variables.append( + OutputVariable( + name=output_variable, + type=node_data["outputs"][output_variable]["type"], + ) + ) + + code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8") + + return CodeNodeConfig( + input_variables=input_variables, + language=node_data["code_language"], + output_variables=output_variables, + code=code + ).model_dump() + + def convert_http_node_config(self, node: dict) -> dict: + node_data = node["data"] + if node_data["authorization"] != 'no-auth': + auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"]) + auth_config = HttpAuthConfig( + auth_type=auth_type, + header=node_data["authorization"]["config"].get("header"), + api_key=node_data["authorization"]["config"].get("api_key"), + ) + else: + auth_config = HttpAuthConfig() + + content_type = self.convert_http_content_type(node_data["body"]["type"]) + if content_type == HttpContentType.FROM_DATA: + body_content = [] + for content in node_data["body"]["data"]: + body_content.append( + HttpFormData( + key=self.trans_variable_format(content["key"]), + type=content["type"], + value=self.trans_variable_format(content["value"]), + ) + ) + elif content_type == HttpContentType.WWW_FORM: + body_content = {} + for content in node_data["body"]["data"]: + body_content[ + self.trans_variable_format(content["key"]) + ] = self.trans_variable_format(content["value"]) + else: + if node_data["body"]["data"]: + body_content = node_data["body"]["data"][0]["value"] + else: + body_content = "" + + headers = {} + for header in node_data["headers"].split("\n"): + if not header: + continue + + key_value = header.split(":") + if len(key_value) == 2: + headers[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {header}", + )) + + params = {} + for param in node_data["params"].split("\n"): + if not param: + continue + + key_value = param.split(":") + if len(key_value) == 2: + params[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {param}", + )) + + error_handle_type = self.convert_http_error_handle_type( + node_data.get("error_strategy", "none") + ) + default_value = None + if error_handle_type == HttpErrorHandle.DEFAULT: + default_body = "" + default_header = {} + default_status_code = 0 + for var in node_data["default_value"]: + if var["key"] == "body": + default_body = var["value"] + elif var["key"] == "header": + default_header = var["value"] + elif var["key"] == "status_code": + default_status_code = var["value"] + default_value = HttpErrorDefaultTamplete( + body=default_body, + headers=default_header, + status_code=default_status_code, + ) + + self.error_branch_node_cache.append(node['id']) + return HttpRequestNodeConfig( + method=node_data["method"].upper(), + url=node_data["url"], + auth=auth_config, + body=HttpContentTypeConfig( + content_type=self.convert_http_content_type(node_data["body"]["type"]), + data=body_content, + ), + headers=headers, + params=params, + verify_ssl=node_data["ssl_verify"], + timeouts=HttpTimeOutConfig( + connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, + read_timeout=node_data["timeout"]["max_read_timeout"] or 5, + write_timeout=node_data["timeout"]["max_write_timeout"] or 5, + ), + retry=HttpRetryConfig( + enable=node_data["retry_config"]["retry_enabled"], + max_attempts=node_data["retry_config"]["max_retries"], + retry_interval=node_data["retry_config"]["retry_interval"], + ), + error_handle=HttpErrorHandleConfig( + method=error_handle_type, + default=default_value, + ) + ).model_dump() + + def convert_jinja_render_node_config(self, node: dict) -> dict: + node_data = node["data"] + mapping = [] + for variable in node_data["variables"]: + mapping.append(VariablesMappingConfig( + name=variable["variable"], + value=self._process_list_variable_litearl(variable["value_selector"]) + )) + return JinjaRenderNodeConfig( + template=node_data["template"], + mapping=mapping, + ).model_dump() + + def convert_knowledge_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append(ExceptionDefineition( + node_id=node["id"], + node_name=node_data["title"], + type=ExceptionType.CONFIG, + detail=f"Please reconfigure the Knowledge Retrieval node.", + )) + return KnowledgeRetrievalNodeConfig.model_construct( + query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + ).model_dump() + + def convert_parameter_extractor_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + params = [] + for param in node_data["parameters"]: + params.append( + ParamsConfig( + name=param["name"], + desc=param["description"], + required=param["required"], + type=param["type"], + ) + ) + return ParameterExtractorNodeConfig.model_construct( + text=self._process_list_variable_litearl(node_data["query"]), + params=params, + prompt=node_data["instruction"] + ).model_dump() + + def convert_variable_aggregator(self, node: dict) -> dict: + node_data = node["data"] + group_enable = node_data["advanced_settings"]["group_enabled"] + group_variables = {} + group_type = {} + if not group_enable: + group_variables["output"] = [ + self._process_list_variable_litearl(variable) + for variable in node_data["variables"] + ] + group_type["output"] = node_data["output_type"] + else: + for group in node_data["advanced_settings"]["groups"]: + group_variables[group["group_name"]] = [ + self._process_list_variable_litearl(variable) + for variable in group["variables"] + ] + group_type[group["group_name"]] = group["output_type"] + + return VariableAggregatorNodeConfig( + group=group_enable, + group_variables=group_variables, + group_type=group_type, + ).model_dump() diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py new file mode 100644 index 00000000..48a0cbd6 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -0,0 +1,239 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 16:05 +from typing import Any + +from app.core.logging_config import get_logger +from app.core.workflow.adapters.base_adapter import ( + BasePlatformAdapter, + PlatformMetadata, + PlatformType, + WorkflowParserResult +) +from app.core.workflow.adapters.dify.converter import DifyConverter +from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.nodes.enums import NodeType +from app.schemas.workflow_schema import ( + NodeDefinition, + EdgeDefinition, + VariableDefinition, + TriggerConfig, + ExecutionConfig +) + +logger = get_logger() + + +class DifyAdapter(BasePlatformAdapter, DifyConverter): + NODE_TYPE_MAPPING = { + "start": NodeType.START, + "llm": NodeType.LLM, + "answer": NodeType.END, + "if-else": NodeType.IF_ELSE, + "loop-start": NodeType.CYCLE_START, + "iteration-start": NodeType.CYCLE_START, + "assigner": NodeType.ASSIGNER, + "loop": NodeType.LOOP, + "iteration": NodeType.ITERATION, + "loop-end": NodeType.BREAK, + "code": NodeType.CODE, + "http-request": NodeType.HTTP_REQUEST, + "template-transform": NodeType.JINJARENDER, + "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL, + "parameter-extractor": NodeType.PARAMETER_EXTRACTOR, + "question-classifier": NodeType.QUESTION_CLASSIFIER, + "variable-aggregator": NodeType.VAR_AGGREGATOR + } + + def __init__(self, config: dict[str, Any]): + DifyConverter.__init__(self) + BasePlatformAdapter.__init__(self, config) + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.DIFY, + version="0.5.0", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return self.NODE_TYPE_MAPPING.get(platform_node_type) + + @property + def origin_nodes(self): + return self.config.get("workflow").get("graph").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("graph").get("edges") + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "data" not in node: + return False + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + for node in self.origin_nodes: + node = self._convert_node(node) + if node: + self.nodes.append(node) + nodes_id = [node.id for node in self.nodes] + for edge in self.origin_edges: + source = edge["source"] + target = edge["target"] + if source not in nodes_id or target not in nodes_id: + continue + edge = self._convert_edge(edge) + if edge: + self.edges.append(edge) + # + for variable in self.config.get("workflow").get("conversation_variables"): + con_var = self._convert_variable(variable) + if variable: + self.conv_variables.append(con_var) + # + # for variables in config.get("workflow").get("environment_variables"): + # variable = self._convert_variable(variables) + # conv_variables.append(variable) + + trigger = self._convert_trigger({}) + execution_config = self._convert_execution({}) + + return WorkflowParserResult( + success=not self.errors and not self.warnings, + platform=self.get_metadata(), + execution_config=execution_config, + origin_config=self.config, + trigger=trigger, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors + ) + + def _convert_cycle_node_position(self, node_id: str, position: dict): + for node in self.origin_nodes: + if node["id"] == node_id: + return { + "x": node["position"]["x"] + position["x"], + "y": node["position"]["y"] + position["y"] + } + self.errors.append( + ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node_id, + detail="parent cycle node not found" + ) + ) + raise Exception("parent cycle node not found") + + def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: + node_data = node["data"] + try: + return NodeDefinition( + id=node["id"], + type=self.map_node_type(node_data["type"]), + name=node_data.get("title"), + cycle=node.get("parentId"), + description=None, + config=self._convert_node_config(node), + position={ + "x": node["position"]["x"], + "y": node["position"]["y"] + } if node.get("parentId") is None else self._convert_cycle_node_position( + node["parentId"], + node["position"] + ), + error_handling=None, + cache=None + ) + except Exception as e: + logger.debug(f"convert node error - {e}", exc_info=True) + + def _convert_node_config(self, node: dict): + node_data = node["data"] + node_type = node_data["type"] + try: + converter = self.get_node_convert(node_type) + if converter is None: + raise Exception(f"node type not supported - {node_type}") + return converter(node) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node["id"], + node_name=node["data"]["title"], + detail=f"convert node error - {e}", + )) + raise e + + def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: + try: + + source = edge["source"] + target = edge["target"] + edge_id = edge["id"] + label = None + if source in self.branch_node_cache: + case_id = "-".join(edge_id.split("-")[1:-2]) + if case_id == "false": + label = f'CASE{len(self.branch_node_cache[source])+1}' + else: + label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' + if source in self.error_branch_node_cache: + case_id = "-".join(edge_id.split("-")[1:-2]) + if case_id == "source": + label = "SUCCESS" + else: + label = "ERROR" + return EdgeDefinition( + id=edge["id"], + source=source, + target=target, + label=label, + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.EDGE, + detail=f"convert edge error - {e}", + )) + return None + + def _convert_variable(self, variable) -> VariableDefinition | None: + try: + return VariableDefinition( + name=variable["name"], + default=variable["value"], + type=variable["value_type"], + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.VARIABLE, + name=variable.get("name"), + detail=f"convert variable error - {e}", + )) + + def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None: + pass + + def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: + return ExecutionConfig() + + diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py new file mode 100644 index 00000000..c0340a5e --- /dev/null +++ b/api/app/core/workflow/adapters/errors.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:29 +from enum import StrEnum + +from pydantic import BaseModel + + +class ExceptionType(StrEnum): + NODE = "node" + EDGE = "edge" + VARIABLE = "variable" + TRIGGER = "trigger" + EXECUTION = "execution" + CONFIG = "config" + PLATFORM = "platform" + UNKNOWN = "unknown" + + +class ExceptionDefineition(BaseModel): + type: ExceptionType + detail: str + + node_id: str | None = None + node_name: str | None = None + + scope: str | None = None + name: str | None = None + + +class UnknowModelWarning(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id, node_name, model_name): + super().__init__( + detail=f"Please specify the model mapping manually for model: {model_name}", + node_id=node_id, + node_name=node_name + ) + + +class UnknowError(ExceptionDefineition): + type: ExceptionType = ExceptionType.UNKNOWN + + def __init__(self, detail: str, **kwargs): + super().__init__(detail=detail, **kwargs) + + +class UnsupportPlatform(ExceptionDefineition): + type: ExceptionType = ExceptionType.PLATFORM + + def __init__(self, platform: str): + super().__init__(detail=f"Unsupport platform {platform}") + + +class UnsupportVariableType(ExceptionDefineition): + type: ExceptionType = ExceptionType.VARIABLE + + def __init__(self, scope, name, var_type: str, **kwargs): + super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + + +class InvalidConfiguration(ExceptionDefineition): + type: ExceptionType = ExceptionType.CONFIG + + def __init__(self): + super().__init__(detail="Invalid workflow configuration format") + + +class UnsupportNodeType(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id: str, node_type: str): + super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/__init__.py b/api/app/core/workflow/adapters/memory_bear/__init__.py new file mode 100644 index 00000000..f314662f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:30 diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py new file mode 100644 index 00000000..0e3f459f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -0,0 +1,76 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:11 +from typing import Any + +from app.core.workflow.adapters.base_adapter import ( + PlatformMetadata, + PlatformType, + BasePlatformAdapter, + WorkflowParserResult +) +from app.schemas.workflow_schema import ExecutionConfig + + +class MemoryBearAdapter(BasePlatformAdapter): + NODE_TYPE_MAPPING = {} + + @property + def origin_nodes(self): + return self.config.get("workflow").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("edges") + + @property + def origin_variables(self): + return self.config.get("workflow").get("variables") + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.MEMORY_BEAR, + version="0.2.5", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return platform_node_type + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + self.nodes = self.origin_nodes + self.edges = self.origin_edges + self.conv_variables = self.origin_variables + + return WorkflowParserResult( + success=True, + platform=self.get_metadata(), + execution_config=ExecutionConfig(), + origin_config=self.config, + trigger=None, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors, + + ) diff --git a/api/app/core/workflow/adapters/registry.py b/api/app/core/workflow/adapters/registry.py new file mode 100644 index 00000000..10012676 --- /dev/null +++ b/api/app/core/workflow/adapters/registry.py @@ -0,0 +1,34 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:19 +from typing import Any + +from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter +from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType + + +class PlatformAdapterRegistry: + _adapters: dict[str, type[BasePlatformAdapter]] = {} + + @classmethod + def register(cls, platform: str, adapter: type[BasePlatformAdapter]): + cls._adapters[platform] = adapter + + @classmethod + def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter: + if platform not in cls._adapters: + raise ValueError(f"Unsupported platform: {platform}") + return cls._adapters.get(platform)(config) + + @classmethod + def list_platforms(cls) -> list[str]: + return list(cls._adapters.keys()) + + @classmethod + def is_supported(cls, platform: str) -> bool: + return platform in cls._adapters + + +PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter) +PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter) diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index 5155a76f..ba6af156 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool logger = get_logger(__name__) SCOPE_PATTERN = re.compile( - r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}" + r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}" ) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index be51f81d..4c897d5a 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -88,6 +88,8 @@ class AssignerNode(BaseNode): await operator.remove_first() case AssignmentOperator.REMOVE_LAST: await operator.remove_last() + case AssignmentOperator.EXTEND: + await operator.extend() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") logger.info(f"Node {self.node_id}: execution completed") diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index f534dfb5..5c2a6c2a 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig): description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="output", - type=VariableType.STRING, - description="工作流的最终输出" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="output", + # type=VariableType.STRING, + # description="工作流的最终输出" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 6ad1c6a8..0579bdf5 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -61,6 +61,7 @@ class AssignmentOperator(StrEnum): APPEND = "append" REMOVE_LAST = "remove_last" REMOVE_FIRST = "remove_first" + EXTEND = "extend" class HttpRequestMethod(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index cdb34b57..df899940 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -236,5 +236,5 @@ class HttpRequestNode(BaseNode): logger.warning( f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) - return "ERROR" + return {"output": "ERROR"} raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index 5475636e..56afe004 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): ) knowledge_bases: list[KnowledgeBaseConfig] = Field( - ..., + default_factory=list, description="Knowledge base config" ) diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py index 98390bf7..3f795f1e 100644 --- a/api/app/core/workflow/nodes/start/config.py +++ b/api/app/core/workflow/nodes/start/config.py @@ -3,7 +3,6 @@ from pydantic import Field from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType class StartNodeConfig(BaseNodeConfig): @@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig): description="自定义输入变量列表,这些变量会作为 Start 节点的输出" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="message", - type=VariableType.STRING, - description="用户输入的消息" - ), - VariableDefinition( - name="conversation_vars", - type=VariableType.OBJECT, - description="会话级变量" - ), - VariableDefinition( - name="execution_id", - type=VariableType.STRING, - description="执行 ID" - ), - VariableDefinition( - name="conversation_id", - type=VariableType.STRING, - description="会话 ID" - ), - VariableDefinition( - name="workspace_id", - type=VariableType.STRING, - description="工作空间 ID" - ), - VariableDefinition( - name="user_id", - type=VariableType.STRING, - description="用户 ID" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="message", + # type=VariableType.STRING, + # description="用户输入的消息" + # ), + # VariableDefinition( + # name="conversation_vars", + # type=VariableType.OBJECT, + # description="会话级变量" + # ), + # VariableDefinition( + # name="execution_id", + # type=VariableType.STRING, + # description="执行 ID" + # ), + # VariableDefinition( + # name="conversation_id", + # type=VariableType.STRING, + # description="会话 ID" + # ), + # VariableDefinition( + # name="workspace_id", + # type=VariableType.STRING, + # description="工作空间 ID" + # ), + # VariableDefinition( + # name="user_id", + # type=VariableType.STRING, + # description="用户 ID" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 8cf81b92..eeb73a01 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -5,6 +5,8 @@ from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator +from app.schemas.workflow_schema import WorkflowConfigCreate + # ---------- Multimodal File Support ---------- @@ -196,6 +198,8 @@ class AppCreate(BaseModel): # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None + workflow_config: Optional[WorkflowConfigCreate] = None + class AppUpdate(BaseModel): name: Optional[str] = None diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index bdef825e..9e15f227 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -18,7 +18,10 @@ class NodeConfig(BaseModel): class NodeDefinition(BaseModel): """节点定义""" id: str = Field(..., description="节点唯一标识") - type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") + type: str = Field( + ..., + description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code" + ) name: str | None = Field(None, description="节点名称") cycle: str | None = Field(None, description="父循环节点id") description: str | None = Field(None, description="节点描述") @@ -30,12 +33,12 @@ class NodeDefinition(BaseModel): class EdgeDefinition(BaseModel): """边定义""" - id: str | None = Field(None, description="边唯一标识(可选)") + id: str | None = Field(default=None, description="边唯一标识(可选)") source: str = Field(..., description="源节点 ID") target: str = Field(..., description="目标节点 ID") - type: str | None = Field(None, description="边类型: normal, error") - condition: str | None = Field(None, description="条件表达式(条件边)") - label: str | None = Field(None, description="边标签") + type: str | None = Field(default=None, description="边类型: normal, error") + condition: str | None = Field(default=None, description="条件表达式(条件边)") + label: str | None = Field(default=None, description="边标签") class VariableDefinition(BaseModel): @@ -44,7 +47,7 @@ class VariableDefinition(BaseModel): type: str = Field(default="string", description="变量类型: string, number, boolean, object, array") required: bool = Field(default=False, description="是否必填") default: Any = Field(None, description="默认值") - description: str | None = Field(None, description="变量描述") + description: str | None = Field(default=None, description="变量描述") class ExecutionConfig(BaseModel): @@ -61,6 +64,13 @@ class TriggerConfig(BaseModel): config: dict[str, Any] = Field(default_factory=dict, description="触发器配置") +class WorkflowImportSave(BaseModel): + """工作流导入请求""" + temp_id: str + name: str + description: str + + # ==================== 工作流配置 ==================== class WorkflowConfigCreate(BaseModel): @@ -84,7 +94,7 @@ class WorkflowConfigUpdate(BaseModel): class WorkflowConfig(BaseModel): """工作流配置输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID app_id: uuid.UUID nodes: list[dict[str, Any]] @@ -95,11 +105,11 @@ class WorkflowConfig(BaseModel): is_active: bool created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel): output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据") error_message: str | None = Field(None, description="错误信息") elapsed_time: float | None = Field(None, description="耗时(秒)") - token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") + token_usage: dict[str, Any] | None = Field(None, + description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") class WorkflowExecutionStreamChunk(BaseModel): @@ -136,7 +147,7 @@ class WorkflowExecutionStreamChunk(BaseModel): class WorkflowExecution(BaseModel): """工作流执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID workflow_config_id: uuid.UUID app_id: uuid.UUID @@ -156,15 +167,15 @@ class WorkflowExecution(BaseModel): token_usage: dict[str, Any] | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -173,7 +184,7 @@ class WorkflowExecution(BaseModel): class WorkflowNodeExecution(BaseModel): """工作流节点执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID execution_id: uuid.UUID node_id: str @@ -193,15 +204,15 @@ class WorkflowNodeExecution(BaseModel): cache_key: str | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index f3c6260a..6e6e0ecb 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -321,6 +321,26 @@ class AppService: self.db.add(agent_cfg) logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)}) + def _create_workflow_config( + self, + app_id: uuid.UUID, + data: app_schema.WorkflowConfigCreate, + now: datetime.datetime + ): + workflow_cfg = WorkflowConfig( + id=uuid.uuid4(), + app_id=app_id, + nodes=[node.model_dump() for node in data.nodes] if data.nodes else [], + edges=[edge.model_dump() for edge in data.edges] if data.edges else [], + variables=[var.model_dump() for var in data.variables] if data.variables else [], + execution_config=data.execution_config.model_dump() if data.execution_config else {}, + triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + is_active=True, + created_at=now, + updated_at=now + ) + self.db.add(workflow_cfg) + def _create_multi_agent_config( self, app_id: uuid.UUID, @@ -532,6 +552,9 @@ class AppService: if app.type == "multi_agent" and data.multi_agent_config: self._create_multi_agent_config(app.id, data.multi_agent_config, now) + if app.type == "workflow" and data.workflow_config: + self._create_workflow_config(app.id, data.workflow_config, now) + self.db.commit() self.db.refresh(app) @@ -968,7 +991,7 @@ class AppService: config = self.db.scalars(stmt).first() try: - config_memory=config.memory + config_memory = config.memory if 'memory_content' in config_memory: config.memory['memory_config_id'] = config.memory.pop('memory_content') except: @@ -1189,9 +1212,9 @@ class AppService: # ==================== 记忆配置提取方法 ==================== def _extract_memory_config_id( - self, - app_type: str, - config: Dict[str, Any] + self, + app_type: str, + config: Dict[str, Any] ) -> Tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(委托给 MemoryConfigService) @@ -1205,13 +1228,13 @@ class AppService: - is_legacy_int: 是否检测到旧格式 int 数据,需要回退到工作空间默认配置 """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) return service.extract_memory_config_id(app_type, config) def _get_workspace_default_memory_config_id( - self, - workspace_id: uuid.UUID + self, + workspace_id: uuid.UUID ) -> Optional[uuid.UUID]: """获取工作空间的默认记忆配置ID @@ -1222,22 +1245,22 @@ class AppService: Optional[uuid.UUID]: 默认记忆配置ID,如果不存在则返回 None """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) config = service.get_workspace_default_config(workspace_id) - + if not config: logger.warning( f"工作空间没有可用的记忆配置: workspace_id={workspace_id}" ) return None - + return config.config_id def _update_endusers_memory_config( - self, - app_id: uuid.UUID, - memory_config_id: uuid.UUID + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id @@ -1249,13 +1272,13 @@ class AppService: int: 更新的终端用户数量 """ from app.repositories.end_user_repository import EndUserRepository - + repo = EndUserRepository(self.db) updated_count = repo.batch_update_memory_config_id( app_id=app_id, memory_config_id=memory_config_id ) - + return updated_count # ==================== 应用发布管理 ==================== @@ -1403,7 +1426,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1412,7 +1435,7 @@ class AppService: f"发布时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( @@ -1537,7 +1560,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1546,7 +1569,7 @@ class AppService: f"回滚时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py new file mode 100644 index 00000000..2e17f404 --- /dev/null +++ b/api/app/services/workflow_import_service.py @@ -0,0 +1,102 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:39 +import json +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +from app.aioRedis import aio_redis_set, aio_redis_get +from app.core.config import settings +from app.core.exceptions import BusinessException +from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult +from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.schemas import AppCreate +from app.schemas.workflow_schema import WorkflowConfigCreate +from app.services.app_service import AppService +from app.services.workflow_service import WorkflowService + + +class WorkflowImportService: + def __init__(self, db: Session): + self.db = db + self.registry = PlatformAdapterRegistry + self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT + + self.app_service = AppService(db) + self.workflow_service = WorkflowService(db) + + async def flush_config(self, temp_id: str, config: WorkflowParserResult): + config_cache = await aio_redis_get(temp_id) + if not config_cache: + raise BusinessException("Workflow configuration has expired. Please re-upload it.") + await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout) + + async def upload_config( + self, + platform: str, + config: dict[str, Any], + ): + + if not self.registry.is_supported(platform): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[UnsupportPlatform(platform=platform)] + ) + + adapter = self.registry.get_adapter(platform, config) + + if not adapter.validate_config(): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[InvalidConfiguration()] + ) + + workflow_config = adapter.parse_workflow() + temp_id = uuid.uuid4().hex + await aio_redis_set(temp_id, workflow_config.model_dump(), expire=self.cache_timeout) + return WorkflowImportResult( + success=True, + temp_id=temp_id, + workflow_id=None, + edges=workflow_config.edges, + nodes=workflow_config.nodes, + variables=workflow_config.variables, + warnings=workflow_config.warnings, + errors=workflow_config.errors + ) + + async def save_workflow( + self, + user_id: uuid.UUID, + workspace_id: uuid.UUID, + temp_id: str, + name: str, + description: str | None, + ): + config = await aio_redis_get(temp_id) + if config is None: + raise BusinessException("Configuration import timed out. Please try again.") + config = json.loads(config) + app = self.app_service.create_app( + user_id=user_id, + workspace_id=workspace_id, + data=AppCreate( + name=name, + description=description, + type="workflow", + workflow_config=WorkflowConfigCreate( + nodes=config["nodes"], + edges=config["edges"], + variables=config["variables"] + ) + ) + ) + return app diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d06a05d7..188ef6cd 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -6,13 +6,16 @@ import logging import uuid from typing import Any, Annotated, Optional +import yaml from fastapi import Depends from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException +from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.validator import validate_workflow_config from app.db import get_db +from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.workflow_repository import ( WorkflowConfigRepository, @@ -38,6 +41,8 @@ class WorkflowService: self.conversation_service = ConversationService(db) self.multimodal_service = MultimodalService(db) + self.registry = PlatformAdapterRegistry + # ==================== 配置管理 ==================== def create_workflow_config( @@ -200,6 +205,32 @@ class WorkflowService: logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") return True + def export_workflow_dsl(self, app_id: uuid.UUID): + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + code=BizCode.NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + app: App = config.app + dsl_info = { + "app": { + "name": app.name, + "description": app.description, + "icon": app.icon, + "icon_type": app.icon_type + }, + "workflow": { + "variables": config.variables, + "edges": config.edges, + "nodes": config.nodes, + "execution_config": config.execution_config, + "triggers": config.triggers + } + } + return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True) + def check_config(self, app_id: uuid.UUID) -> WorkflowConfig: """检查工作流配置的完整性 From 5f211620c533068d8b2882263c6ada2219486a26 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Sat, 28 Feb 2026 14:01:49 +0800 Subject: [PATCH 010/164] fix(app): Lock the conversation with the application dialogue --- .../controllers/service/app_api_controller.py | 4 +-- api/app/version_info.json | 30 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index bb71d831..61a919b1 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -89,7 +89,6 @@ async def chat( body = await request.json() payload = AppChatRequest(**body) - other_id = payload.user_id app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) other_id = payload.user_id workspace_id = app.workspace_id @@ -135,7 +134,8 @@ async def chat( app_id=app.id, workspace_id=workspace_id, user_id=end_user_id, - is_draft=False + is_draft=False, + conversation_id=payload.conversation_id ) if app_type == AppType.AGENT: diff --git a/api/app/version_info.json b/api/app/version_info.json index aea03dcd..7d82eabc 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,34 @@ { + "v0.2.5": { + "introduction": { + "codeName": "行云", + "releaseDate": "2026-2-26", + "upgradePosition": "🐻 精炼根基,优化核心用户体验与系统稳定性", + "coreUpgrades": [ + "1. 用户体验与国际化 🎨
* 语言参数修复:语言偏好现正确保留
* 邮箱修改支持:用户可直接在用户管理系统中修改邮箱地址", + "2. 工作流可视化增强 💬
* 循环与迭代节点输出展示:实时显示执行进度和中间输出,便于调试复杂迭代过程
* 变量支持回车选择:支持回车键确认变量选择,简化工作流配置流程", + "3. 优化模型管理 ⚙️
* 模型广场移除自定义模型,优化模型使用体验", + "4. 稳健性与缺陷修复 🔧
* 知识图谱构建修复:解决知识图谱构建流程稳定性问题,确保更可靠的实体提取和关系映射", + "
", + "版本 0.2.5 通过解决国际化边界情况和改进工作流透明度,构建更具生产就绪性的平台。工作流可视化改进为更复杂的调试和监控能力奠定基础。未来将继续深化企业就绪性,扩展用户管理功能、优化知识图谱智能和增强工作流编排能力,在可观测性、性能优化和无缝集成模式方面持续改进。", + "智慧致远 🐻✨" + ] + }, + "introduction_en": { + "codeName": "Flowing Clouds", + "releaseDate": "2026-2-26", + "upgradePosition": "🐻 Refined foundations with enhanced user experience and system stability", + "coreUpgrades": [ + "1. User Experience & Internationalization 🎨
* Language parameter fix: language preferences are now correctly retained
* Email Update Support: Users can now modify email addresses directly in user management system", + "2. Workflow Visualization Enhancements 💬
* Loop & Iteration Node Output Display: Real-time display of execution progress and intermediate outputs for easier debugging
* Variable Selection with Enter Key: Enabled Enter key confirmation for streamlined variable assignment", + "3. Optimized Model Management ⚙️
* Custom models have been removed from the Model marketplace to optimize the model usage experience", + "4. Robustness & Bug Fixes 🔧
* Knowledge Graph Construction Fix: Addressed stability issues in knowledge graph pipeline for more reliable entity extraction and relationship mapping", + "
", + "Version 0.2.5 matures MemoryBear's operational foundations by addressing internationalization edge cases and improving workflow transparency. The workflow visualization improvements lay groundwork for sophisticated debugging and monitoring capabilities. Looking forward, we will deepen enterprise readiness by expanding user management features, refining knowledge graph intelligence, and enhancing workflow orchestration with continued improvements in observability, performance optimization, and seamless integration patterns.", + "Intelligent Resilience 🐻✨" + ] + } + }, "v0.2.4": { "introduction": { "codeName": "智远", From 1037729fb3457109e68315dac0309f18177381f6 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Sat, 28 Feb 2026 16:51:56 +0800 Subject: [PATCH 011/164] fix(model): The custom models in the model list can batch add APIkeys through the provider --- api/app/repositories/model_repository.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 2c513e82..f49227d3 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -428,19 +428,17 @@ class ModelConfigRepository: try: # 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id - model_config_ids = db.query(ModelConfig.id).join( - ModelBase, ModelConfig.model_id == ModelBase.id - ).filter( + model_config_ids = db.query(ModelConfig.id).filter( and_( or_( ModelConfig.tenant_id == tenant_id, ModelConfig.is_public ), - ModelBase.provider == provider, + ModelConfig.provider == provider, ModelConfig.is_active, ~ModelConfig.is_composite ) - ).distinct().all() + ).all() db_logger.debug(f"查询成功: 数量={len(model_config_ids)}") return [row[0] for row in model_config_ids] From 3a0671c661baee7000c1ab3ae43daf69bb11d811 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 17:18:42 +0800 Subject: [PATCH 012/164] [add]The semantic pruning function is activated, removing the protection of question-answer pairs. --- .../core/memory/agent/utils/get_dialogs.py | 57 +- .../data_preprocessing/data_pruning.py | 511 ++++++++---------- .../data_preprocessing/scene_config.py | 326 +++++++++++ .../extraction_orchestrator.py | 16 +- 4 files changed, 619 insertions(+), 291 deletions(-) create mode 100644 api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index bfb0f675..22555fff 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -21,7 +21,7 @@ async def get_chunked_dialogs( end_user_id: Group identifier messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier - config_id: Configuration ID for processing + config_id: Configuration ID for processing (used to load pruning config) Returns: List of DialogData objects with generated chunks @@ -57,6 +57,61 @@ async def get_chunked_dialogs( end_user_id=end_user_id, config_id=config_id ) + + # 语义剪枝步骤(在分块之前) + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner + from app.core.memory.models.config_models import PruningConfig + from app.db import get_db_context + from app.services.memory_config_service import MemoryConfigService + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # 加载剪枝配置 + pruning_config = None + if config_id: + try: + with get_db_context() as db: + # 使用 MemoryConfigService 加载完整的 MemoryConfig 对象 + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="semantic_pruning" + ) + + if memory_config: + pruning_config = PruningConfig( + pruning_switch=memory_config.pruning_enabled, + pruning_scene=memory_config.pruning_scene or "education", + pruning_threshold=memory_config.pruning_threshold + ) + logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") + + # 获取LLM客户端用于剪枝 + if pruning_config.pruning_switch: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) + + # 执行剪枝 - 使用 prune_dataset 支持消息级剪枝 + pruner = SemanticPruner(config=pruning_config, llm_client=llm_client) + original_msg_count = len(dialog_data.context.msgs) + + # 使用 prune_dataset 而不是 prune_dialog + # prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息 + pruned_dialogs = await pruner.prune_dataset([dialog_data]) + + if pruned_dialogs: + dialog_data = pruned_dialogs[0] + remaining_msg_count = len(dialog_data.context.msgs) + deleted_count = original_msg_count - remaining_msg_count + logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)") + else: + logger.warning("[剪枝] prune_dataset 返回空列表") + else: + logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝") + except Exception as e: + logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True) + except Exception as e: + logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True) chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 2d0142c6..d932c542 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -22,6 +22,10 @@ from app.core.memory.models.message_models import DialogData, ConversationMessag from app.core.memory.models.config_models import PruningConfig from app.core.memory.utils.config.config_utils import get_pruning_config from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering +from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( + SceneConfigRegistry, + ScenePatterns +) class DialogExtractionResponse(BaseModel): @@ -78,6 +82,20 @@ class SemanticPruner: self.language = language # 保存语言配置 self.max_concurrent = max_concurrent # 新增:最大并发数 + # 加载场景特定配置 + self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( + self.config.pruning_scene, + fallback_to_generic=True + ) + + # 检查场景是否有专门支持 + is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) + if is_supported: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置") + else: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)") + self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}") + # Load Jinja2 template self.template = prompt_env.get_template("extracat_Pruning.jinja2") @@ -87,108 +105,80 @@ class SemanticPruner: # 运行日志:收集关键终端输出,便于写入 JSON self.run_logs: List[str] = [] - - # 扩展的填充词库(包含表情符号和网络用语) - self._extended_fillers = [ - # 基础寒暄 - "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", - "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", - "拜拜", "再见", "88", "拜", "回见", - # 口头禅 - "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", - "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", - # 确认词 - "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", - # 标点和符号 - "。。。", "...", "???", "???", "!!!", "!!!", - # 表情符号(文本形式) - "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", - "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", - "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", - "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", - # 网络用语 - "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", - "emmm", "emm", "em", "mmp", "wtf", "omg", - ] def _is_important_message(self, message: ConversationMessage) -> bool: """基于启发式规则识别重要信息消息,优先保留。 - 改进版:增强了规则覆盖范围,包括: - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午) - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址" - - 新增:问句识别、决策性语句、承诺性语句 + 改进版:使用场景特定的模式进行识别 + - 根据 pruning_scene 动态加载对应的识别规则 + - 支持教育、在线服务、外呼三个场景的特定模式 """ text = message.msg.strip() if not text: return False - patterns = [ - # 原有模式 - r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界 - r"\d{1,2}:\d{2}", # 修复:移除 \b - r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月", - r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号", - r"电话|手机号|微信|QQ|邮箱|联系方式", - r"地址|地点|位置|门牌号", - r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元", - r"时间|日期|有效期|截止|期限|到期", - # 新增模式 - r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词 - r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句 - r"承诺|保证|确保|负责|同意|答应", # 承诺性语句 - r"\d{11}", # 11位手机号 - r"\d{3,4}-\d{7,8}", # 固定电话 - r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱 - ] + # 使用场景特定的模式 + all_patterns = ( + self.scene_config.high_priority_patterns + + self.scene_config.medium_priority_patterns + + self.scene_config.low_priority_patterns + ) - for p in patterns: - if re.search(p, text, flags=re.IGNORECASE): + for pattern, _ in all_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): return True # 检查是否为问句(以问号结尾或包含疑问词) if text.endswith("?") or text.endswith("?"): return True + + # 检查是否包含问句关键词 + if any(keyword in text for keyword in self.scene_config.question_keywords): + return True + + # 检查是否包含决策性关键词 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + return True return False def _importance_score(self, message: ConversationMessage) -> int: """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - 改进版:更细致的评分体系(0-10分) + 改进版:使用场景特定的权重体系(0-10分) + - 根据场景动态调整不同信息类型的权重 + - 高优先级模式:4-6分 + - 中优先级模式:2-3分 + - 低优先级模式:1分 """ text = message.msg.strip() score = 0 - weights = [ - # 高优先级(4-5分) - (r"订单号|工单|申请号|编号|ID|账号|账户", 5), - (r"金额|费用|价格|¥|¥|\d+元", 5), - (r"\d{11}", 4), # 手机号 - (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 - - # 中优先级(2-3分) - (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b - (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"电话|手机号|微信|QQ|联系方式", 3), - (r"地址|地点|位置", 2), - (r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词 - - # 低优先级(1分) - (r"\d{1,2}:\d{2}", 1), # 修复:移除 \b - (r"上午|下午|AM|PM", 1), - ] + # 使用场景特定的权重 + for pattern, weight in self.scene_config.high_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight - for p, w in weights: - if re.search(p, text, flags=re.IGNORECASE): - score += w + for pattern, weight in self.scene_config.medium_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight + + for pattern, weight in self.scene_config.low_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight # 问句加分 if text.endswith("?") or text.endswith("?"): score += 2 + # 包含问句关键词加分 + if any(keyword in text for keyword in self.scene_config.question_keywords): + score += 1 + + # 包含决策性关键词加分 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + score += 2 + # 长度加分(较长的消息通常包含更多信息) if len(text) > 50: score += 1 @@ -198,20 +188,35 @@ class SemanticPruner: return min(score, 10) # 最高10分 def _is_filler_message(self, message: ConversationMessage) -> bool: - """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 + """检测典型寒暄/口头禅/确认类短消息。 - 改进版:扩展了填充词库,支持表情符号和网络用语 + 改进版:更严格的填充消息判断,避免误删场景相关内容 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体 - - 在扩展填充词库中 + - 纯标点或空白 + - 在场景特定填充词库中(精确匹配) - 纯表情符号 + - 常见寒暄(精确匹配短语) + + 注意:不再使用长度判断,避免误删短但重要的消息 """ t = message.msg.strip() if not t: return True - # 检查是否在扩展填充词库中 - if t in self._extended_fillers: + # 检查是否在场景特定填充词库中(精确匹配) + if t in self.scene_config.filler_phrases: + return True + + # 常见寒暄和问候(精确匹配,避免误删) + common_greetings = { + "在吗", "在不在", "在呢", "在的", + "你好", "您好", "hello", "hi", + "拜拜", "再见", "拜", "88", "bye", + "好的", "好", "行", "可以", "嗯", "哦", "啊", + "是的", "对", "对的", "没错", "是啊", + "哈哈", "呵呵", "嘿嘿", "嗯嗯" + } + if t in common_greetings: return True # 检查是否为纯表情符号(方括号包裹) @@ -232,13 +237,9 @@ class SemanticPruner: if emoji_pattern.fullmatch(t): return True - # 长度与字符类型判断 - if len(t) <= 8: - # 非数字、无关键实体的短文本 - if not re.search(r"[0-9]", t) and not self._is_important_message(message): - # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t): - return True + # 纯标点符号 + if re.fullmatch(r"[。!?,.!?…·\s]+", t): + return True return False @@ -308,6 +309,8 @@ class SemanticPruner: def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: """识别对话中的问答对,用于保护问答结构的完整性。 + 改进版:使用场景特定的问句关键词,并排除寒暄类问句 + Args: messages: 消息列表 @@ -316,21 +319,39 @@ class SemanticPruner: """ qa_pairs = [] + # 寒暄类问句,不应该被保护(这些不是真正的问答) + greeting_questions = { + "在吗", "在不在", "你好吗", "怎么样", "好吗", + "有空吗", "忙吗", "睡了吗", "起床了吗" + } + for i in range(len(messages) - 1): current_msg = messages[i].msg.strip() next_msg = messages[i + 1].msg.strip() - # 简单规则:如果当前消息是问句,下一条消息可能是答案 - is_question = ( - current_msg.endswith("?") or - current_msg.endswith("?") or - any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"]) - ) + # 排除寒暄类问句 + if current_msg in greeting_questions: + continue + + # 使用场景特定的问句关键词,但要求更严格 + is_question = False + + # 1. 以问号结尾 + if current_msg.endswith("?") or current_msg.endswith("?"): + is_question = True + # 2. 包含实质性问句关键词(排除"吗"这种太宽泛的) + elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]): + is_question = True if is_question and next_msg: - # 检查下一条消息是否像答案(不是另一个问句) + # 检查下一条消息是否像答案(不是另一个问句,也不是寒暄) is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) + # 排除寒暄类回复 + greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"} + if next_msg in greeting_answers: + is_answer = False + if is_answer: qa_pairs.append(QAPair( question_idx=i, @@ -533,10 +554,9 @@ class SemanticPruner: """数据集层面:全局消息级剪枝,保留所有对话。 改进版: - - 并发处理对话级相关性判断 - - 问答对识别和保护 - - 优化删除策略,保持上下文连贯性 - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动 + - 消息级独立判断,每条消息根据场景规则独立评估 + - 问答对保护已注释(暂不启用,留作观察) + - 优化删除策略:填充消息 → 不重要消息 → 低分重要消息 - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 - 保证每段对话至少保留1条消息,不会删除整段对话 """ @@ -553,209 +573,122 @@ class SemanticPruner: proportion = 0.0 self._log( - f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" + f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" ) - # 并发处理对话级相关性分类 - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def classify_dialog(idx: int, dd: DialogData): - async with semaphore: - try: - ex = await self._extract_dialog_important(dd.content) - return { - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - } - except Exception as e: - self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}") - return { - "dialog": dd, - "is_related": True, # 失败时标记为相关,避免误删 - "index": idx, - "extraction": None - } - - # 并发执行所有对话的分类 - tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)] - evaluated_dialogs = await asyncio.gather(*tasks) - - # 统计相关 / 不相关对话 - not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] - related_dialogs = [d for d in evaluated_dialogs if d["is_related"]] - self._log( - f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}" - ) - - # 简洁打印第几段对话相关/不相关(索引基于1) - def _fmt_indices(items, cap: int = 10): - inds = [i["index"] + 1 for i in items] - if len(inds) <= cap: - return inds - return inds[:cap] + ["...", f"共{len(inds)}个"] - - rel_inds = _fmt_indices(related_dialogs) - nrel_inds = _fmt_indices(not_related_dialogs) - self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段") - result: List[DialogData] = [] - if not_related_dialogs: - # 为每个不相关对话进行分析 - per_dialog_info = {} - total_unrelated = 0 + total_original_msgs = 0 + total_deleted_msgs = 0 + + for d_idx, dd in enumerate(dialogs): + msgs = dd.context.msgs + original_count = len(msgs) + total_original_msgs += original_count - for d in not_related_dialogs: - dd = d["dialog"] - extraction = d.get("extraction") - if extraction is None: - extraction = await self._extract_dialog_important(dd.content) - - # 合并所有重要标记 - tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords - msgs = dd.context.msgs - - # 识别问答对 - qa_pairs = self._identify_qa_pairs(msgs) - protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1) - - # 分类消息(考虑问答对保护) - imp_unrel_msgs = [] - unimp_unrel_msgs = [] - - for idx, m in enumerate(msgs): - # 问答对中的消息自动标记为重要 - if idx in protected_indices: - imp_unrel_msgs.append((idx, m)) - elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m): - imp_unrel_msgs.append((idx, m)) - elif not self._is_filler_message(m): - unimp_unrel_msgs.append((idx, m)) - # 填充消息不加入任何列表,优先删除 - - # 重要消息按重要性排序 - imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1])) - imp_sorted_ids = [id(m) for _, m in imp_sorted] - - info = { - "dialog": dd, - "total_msgs": len(msgs), - "unrelated_count": len(msgs), - "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for _, m in unimp_unrel_msgs], - "protected_indices": protected_indices, - "qa_pairs_count": len(qa_pairs), - } - per_dialog_info[d["index"]] = info - total_unrelated += info["unrelated_count"] + # ========== 问答对保护(已注释,暂不启用,留作观察) ========== + # qa_pairs = self._identify_qa_pairs(msgs) + # protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0) + # ======================================================== - # 全局删除配额计算 - global_delete = int(total_unrelated * proportion) - if proportion > 0 and total_unrelated > 0 and global_delete == 0: - global_delete = 1 + # 消息级分类:每条消息独立判断 + important_msgs = [] # 重要消息(保留) + unimportant_msgs = [] # 不重要消息(可删除) + filler_msgs = [] # 填充消息(优先删除) - # 每段的最大可删容量 - capacities = [] - for d in not_related_dialogs: - idx = d["index"] - info = per_dialog_info[idx] - imp_count = len(info["imp_ids_sorted"]) - unimp_count = len(info["unimp_ids"]) - imp_cap = int(imp_count * proportion) - cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) - capacities.append(cap) + for idx, m in enumerate(msgs): + msg_text = m.msg.strip() + + # ========== 问答对保护判断(已注释) ========== + # if idx in protected_indices: + # important_msgs.append((idx, m)) + # self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)") + # ========================================== + + # 填充消息(寒暄、表情等) + if self._is_filler_message(m): + filler_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") + # 重要信息(学号、成绩、时间、金额等) + elif self._is_important_message(m): + important_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)") + # 其他消息 + else: + unimportant_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要") - total_capacity = sum(capacities) - if global_delete > total_capacity: - self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。") - global_delete = total_capacity - - # 配额分配 - alloc = [] - for i, d in enumerate(not_related_dialogs): - idx = d["index"] - info = per_dialog_info[idx] - share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 - alloc.append(min(share, capacities[i])) + # 计算删除配额 + delete_target = int(original_count * proportion) + if proportion > 0 and original_count > 0 and delete_target == 0: + delete_target = 1 - allocated = sum(alloc) - rem = global_delete - allocated - turn = 0 - while rem > 0 and turn < 100000: - progressed = False - for i in range(len(not_related_dialogs)): - if rem <= 0: - break - if alloc[i] < capacities[i]: - alloc[i] += 1 - rem -= 1 - progressed = True - if not progressed: - break - turn += 1 - - # 应用删除 - total_deleted_confirm = 0 - for d in evaluated_dialogs: - dd = d["dialog"] - msgs = dd.context.msgs - original = len(msgs) - - if d["is_related"]: - result.append(dd) - continue - - idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) - if idx_in_unrel is None: - result.append(dd) - continue - - quota = alloc[idx_in_unrel] - info = per_dialog_info[d["index"]] - - # 计算删除ID - imp_count = len(info["imp_ids_sorted"]) - imp_del_cap = int(imp_count * proportion) - - unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) - del_unimp = min(quota, len(unimp_delete_ids)) - rem_quota = quota - del_unimp - - imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) - - deleted_here = 0 - actual_unimp_deleted = 0 - actual_imp_deleted = 0 - kept = [] - - for m in msgs: - mid = id(m) - if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: - actual_unimp_deleted += 1 - deleted_here += 1 - continue - if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids): - actual_imp_deleted += 1 - deleted_here += 1 - continue - kept.append(m) - - if not kept and msgs: - kept = [msgs[0]] - - dd.context.msgs = kept - total_deleted_confirm += deleted_here - - qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else "" - self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}" - ) - result.append(dd) + # 确保至少保留1条消息 + max_deletable = max(0, original_count - 1) + delete_target = min(delete_target, max_deletable) - self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。") - else: - result = [d["dialog"] for d in evaluated_dialogs] + # 删除策略:优先删除填充消息,再删除不重要消息 + to_delete_indices = set() + deleted_details = [] # 记录删除的消息详情 + + # 第一步:删除填充消息 + filler_to_delete = min(len(filler_msgs), delete_target) + for i in range(filler_to_delete): + idx, msg = filler_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'") + + # 第二步:如果还需要删除,删除不重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0: + unimp_to_delete = min(len(unimportant_msgs), remaining_quota) + for i in range(unimp_to_delete): + idx, msg = unimportant_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'") + + # 第三步:如果还需要删除,按重要性分数删除重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0 and important_msgs: + # 按重要性分数排序(分数低的优先删除) + imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1])) + imp_to_delete = min(len(imp_sorted), remaining_quota) + for i in range(imp_to_delete): + idx, msg = imp_sorted[i] + to_delete_indices.add(idx) + score = self._importance_score(msg) + deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'") + + # 执行删除 + kept_msgs = [] + for idx, m in enumerate(msgs): + if idx not in to_delete_indices: + kept_msgs.append(m) + + # 确保至少保留1条 + if not kept_msgs and msgs: + kept_msgs = [msgs[0]] + + dd.context.msgs = kept_msgs + deleted_count = original_count - len(kept_msgs) + total_deleted_msgs += deleted_count + + # 输出删除详情 + if deleted_details: + self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:") + for detail in deleted_details: + self._log(f" {detail}") + + # ========== 问答对统计(已注释) ========== + # qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else "" + # ======================================== + + self._log( + f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " + f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) " + f"删除={deleted_count} 保留={len(kept_msgs)}" + ) + + result.append(dd) self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py new file mode 100644 index 00000000..ed9592af --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py @@ -0,0 +1,326 @@ +""" +场景特定配置 - 为不同场景提供定制化的剪枝规则 + +功能: +- 场景特定的重要信息识别模式 +- 场景特定的重要性评分权重 +- 场景特定的填充词库 +- 场景特定的问答对识别规则 +""" + +from typing import Dict, List, Set, Tuple +from dataclasses import dataclass, field + + +@dataclass +class ScenePatterns: + """场景特定的识别模式""" + + # 重要信息的正则模式(优先级从高到低) + high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight) + medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + + # 填充词库(无意义对话) + filler_phrases: Set[str] = field(default_factory=set) + + # 问句关键词(用于识别问答对) + question_keywords: Set[str] = field(default_factory=set) + + # 决策性/承诺性关键词 + decision_keywords: Set[str] = field(default_factory=set) + + +class SceneConfigRegistry: + """场景配置注册表 - 管理所有场景的特定配置""" + + # 基础通用模式(所有场景共享) + BASE_HIGH_PRIORITY = [ + (r"订单号|工单|申请号|编号|ID|账号|账户", 5), + (r"金额|费用|价格|¥|¥|\d+元", 5), + (r"\d{11}", 4), # 手机号 + (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 + ] + + BASE_MEDIUM_PRIORITY = [ + (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期 + (r"\d{4}年\d{1,2}月\d{1,2}日", 3), + (r"电话|手机号|微信|QQ|联系方式", 3), + (r"地址|地点|位置", 2), + (r"时间|日期|有效期|截止", 2), + (r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重) + (r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3), + (r"今年|去年|明年", 3), + ] + + BASE_LOW_PRIORITY = [ + (r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM + (r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点 + (r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充) + (r"AM|PM|am|pm", 1), + ] + + BASE_FILLERS = { + # 基础寒暄 + "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", + "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", + "拜拜", "再见", "88", "拜", "回见", + # 口头禅 + "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", + "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", + # 确认词 + "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", + # 标点和符号 + "。。。", "...", "???", "???", "!!!", "!!!", + # 表情符号 + "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", + "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", + "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", + "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", + # 网络用语 + "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", + "emmm", "emm", "em", "mmp", "wtf", "omg", + } + + BASE_QUESTION_KEYWORDS = { + "什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗" + } + + BASE_DECISION_KEYWORDS = { + "必须", "一定", "务必", "需要", "要求", "规定", "应该", + "承诺", "保证", "确保", "负责", "同意", "答应" + } + + @classmethod + def get_education_config(cls) -> ScenePatterns: + """教育场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 成绩相关(最高优先级) + (r"成绩|分数|得分|满分|及格|不及格", 6), + (r"GPA|绩点|学分|平均分", 6), + (r"\d+分|\d+\.?\d*分", 5), # 具体分数 + (r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等 + + # 学籍信息 + (r"学号|学生证|教师工号|工号", 5), + (r"班级|年级|专业|院系", 4), + + # 课程相关 + (r"课程|科目|学科|必修|选修", 4), + (r"教材|课本|教科书|参考书", 4), + (r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等 + + # 学科内容(新增) + (r"微积分|导数|积分|函数|极限|微分", 4), + (r"代数|几何|三角|概率|统计", 4), + (r"物理|化学|生物|历史|地理", 4), + (r"英语|语文|数学|政治|哲学", 4), + (r"定义|定理|公式|概念|原理|法则", 3), + (r"例题|解题|证明|推导|计算", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 教学活动 + (r"作业|练习|习题|题目", 3), + (r"考试|测验|测试|考核|期中|期末", 3), + (r"上课|下课|课堂|讲课", 2), + (r"提问|回答|发言|讨论", 2), + (r"问一下|请教|咨询|询问", 2), # 新增:问询相关 + (r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态 + + # 时间安排 + (r"课表|课程表|时间表", 3), + (r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等 + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"老师|教师|同学|学生", 1), + (r"教室|实验室|图书馆", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义) + "老师好", "同学们好", "上课", "下课", "起立", "坐下", + "举手", "请坐", "很好", "不错", "继续", + "下一个", "下一题", "下一位", "还有吗", "还有问题吗", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "为啥", "咋", "咋办", "怎样", "如何做", + "能不能", "可不可以", "行不行", "对不对", "是不是", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "必考", "重点", "考点", "难点", "关键", + "记住", "背诵", "掌握", "理解", "复习", + } + ) + + @classmethod + def get_online_service_config(cls) -> ScenePatterns: + """在线服务场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 工单相关(最高优先级) + (r"工单号|工单编号|ticket|TK\d+", 6), + (r"工单状态|处理中|已解决|已关闭|待处理", 5), + (r"优先级|紧急|高优先级|P0|P1|P2", 5), + + # 产品信息 + (r"产品型号|型号|SKU|产品编号", 5), + (r"序列号|SN|设备号", 5), + (r"版本号|软件版本|固件版本", 4), + + # 问题描述 + (r"故障|错误|异常|bug|问题", 4), + (r"错误代码|故障代码|error code", 5), + (r"无法|不能|失败|报错", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 服务相关 + (r"退款|退货|换货|补发", 4), + (r"发票|收据|凭证", 3), + (r"物流|快递|运单号", 3), + (r"保修|质保|售后", 3), + + # 时效相关 + (r"SLA|响应时间|处理时长", 4), + (r"超时|延迟|等待", 2), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"客服|工程师|技术支持", 1), + (r"用户|客户|会员", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 在线服务特有填充词 + "您好", "请问", "请稍等", "稍等", "马上", "立即", + "正在查询", "正在处理", "正在为您", "帮您查一下", + "还有其他问题吗", "还需要什么帮助", "很高兴为您服务", + "感谢您的耐心等待", "抱歉让您久等了", + "已记录", "已反馈", "已转接", "已升级", + "祝您生活愉快", "再见", "欢迎下次咨询", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "能否", "可否", "是否", "有没有", "能不能", + "怎么办", "如何处理", "怎么解决", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "立即处理", "马上解决", "尽快", "优先", + "升级", "转接", "派单", "跟进", + "补偿", "赔偿", "退款", "换货", + } + ) + + @classmethod + def get_outbound_config(cls) -> ScenePatterns: + """外呼场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 意向相关(最高优先级) + (r"意向|意愿|兴趣|感兴趣", 6), + (r"A类|B类|C类|D类|高意向|低意向", 6), + (r"成交|签约|下单|购买|确认", 6), + + # 联系信息(外呼场景中更重要) + (r"预约|约定|安排|确定时间", 5), + (r"下次联系|回访|跟进", 5), + (r"方便|有空|可以|时间", 4), + + # 通话状态 + (r"接通|未接通|占线|关机|停机", 4), + (r"通话时长|通话时间", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 客户信息 + (r"姓名|称呼|先生|女士", 3), + (r"公司|单位|职位|职务", 3), + (r"需求|要求|期望", 3), + + # 跟进状态 + (r"跟进状态|进展|进度", 3), + (r"已联系|待联系|联系中", 2), + (r"拒绝|不感兴趣|考虑|再说", 3), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"销售|客户经理|业务员", 1), + (r"产品|服务|方案", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 外呼场景特有填充词 + "您好", "喂", "hello", "打扰了", "不好意思", + "方便接电话吗", "现在方便吗", "占用您一点时间", + "我是", "我们是", "我们公司", "我们这边", + "了解一下", "介绍一下", "简单说一下", + "考虑考虑", "想一想", "再说", "再看看", + "不需要", "不感兴趣", "没兴趣", "不用了", + "好的", "行", "可以", "没问题", "那就这样", + "再联系", "回头聊", "有需要再说", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "有没有", "需不需要", "要不要", "考虑不考虑", + "了解吗", "知道吗", "听说过吗", + "方便吗", "有空吗", "在吗", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "确定", "决定", "选择", "购买", "下单", + "预约", "安排", "约定", "确认", + "跟进", "回访", "联系", "沟通", + } + ) + + @classmethod + def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns: + """根据场景名称获取配置 + + Args: + scene: 场景名称 ('education', 'online_service', 'outbound' 或其他) + fallback_to_generic: 如果场景不存在,是否降级到通用配置 + + Returns: + 对应场景的配置,如果场景不存在: + - fallback_to_generic=True: 返回通用配置(仅基础规则) + - fallback_to_generic=False: 抛出异常 + """ + scene_map = { + 'education': cls.get_education_config, + 'online_service': cls.get_online_service_config, + 'outbound': cls.get_outbound_config, + } + + if scene in scene_map: + return scene_map[scene]() + + if fallback_to_generic: + # 返回通用配置(仅包含基础规则,不包含场景特定规则) + return cls.get_generic_config() + else: + raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}") + + @classmethod + def get_generic_config(cls) -> ScenePatterns: + """通用场景配置 - 仅包含基础规则,适用于未定义的场景 + + 这是一个保守的配置,只使用最通用的规则,避免误删重要信息 + """ + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY, + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY, + low_priority_patterns=cls.BASE_LOW_PRIORITY, + filler_phrases=cls.BASE_FILLERS, + question_keywords=cls.BASE_QUESTION_KEYWORDS, + decision_keywords=cls.BASE_DECISION_KEYWORDS + ) + + @classmethod + def get_all_scenes(cls) -> List[str]: + """获取所有预定义场景的列表""" + return ['education', 'online_service', 'outbound'] + + @classmethod + def is_scene_supported(cls, scene: str) -> bool: + """检查场景是否有专门的配置支持 + + Args: + scene: 场景名称 + + Returns: + True: 有专门配置 + False: 将使用通用配置 + """ + return scene in cls.get_all_scenes() diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index a47497da..17bda0e4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: Optional[str] = None, llm_client: Optional[Any] = None, skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2000,6 +2001,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: 输入数据路径 llm_client: LLM 客户端 skip_cleaning: 是否跳过数据清洗步骤(默认False) + pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold Returns: 带 chunks 的 DialogData 列表 @@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing( from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) - pruner = SemanticPruner(llm_client=llm_client) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + if pruning_config: + # 使用传入的配置 + config = PruningConfig(**pruning_config) + print(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + else: + # 使用默认配置(关闭剪枝) + config = None + print("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") + + pruner = SemanticPruner(config=config, llm_client=llm_client) # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None From 54700e6fbe571da04e750625f32deb36ddc9dc5d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Sat, 28 Feb 2026 17:27:07 +0800 Subject: [PATCH 013/164] fix(workflow): fix exceptions when importing configs from Dify --- .../core/workflow/adapters/base_adapter.py | 2 ++ .../core/workflow/adapters/dify/converter.py | 25 +++++++++++++++---- .../workflow/adapters/dify/dify_adapter.py | 12 +++++++-- api/app/core/workflow/executor.py | 4 +-- api/app/core/workflow/nodes/base_node.py | 14 +++++------ .../workflow/utils/expression_evaluator.py | 16 ++++++++++-- .../core/workflow/utils/template_renderer.py | 13 +++++++++- api/app/schemas/workflow_schema.py | 2 +- 8 files changed, 68 insertions(+), 20 deletions(-) diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py index 601c8ff2..49321b89 100644 --- a/api/app/core/workflow/adapters/base_adapter.py +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -68,6 +68,8 @@ class BasePlatformAdapter(ABC): self.branch_node_cache = defaultdict(list) self.error_branch_node_cache = [] + self.node_output_map = {} + @abstractmethod def get_metadata(self) -> PlatformMetadata: """get platform metadata""" diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 0e92b2c7..18beef15 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -44,6 +44,7 @@ class DifyConverter(BaseConverter): warnings: list branch_node_cache: dict error_branch_node_cache: list + node_output_map: dict def __init__(self): self.CONFIG_CONVERT_MAP = { @@ -60,7 +61,8 @@ class DifyConverter(BaseConverter): "knowledge-retrieval": self.convert_knowledge_node_config, "parameter-extractor": self.convert_parameter_extractor_node_config, "question-classifier": self.convert_question_classifier_node_config, - "variable-aggregator": self.convert_variable_aggregator, + "variable-aggregator": self.convert_variable_aggregator_node_config, + "tool": self.convert_tool_node_config, "loop-start": lambda x: {}, "iteration-start": lambda x: {}, "loop-end": lambda x: {}, @@ -74,8 +76,7 @@ class DifyConverter(BaseConverter): def is_variable(expression) -> bool: return bool(re.match(r"\{\{#(.*?)#}}", expression)) - @staticmethod - def process_var_selector(var_selector): + def process_var_selector(self, var_selector): if not var_selector: return "" selector = var_selector.split('.') @@ -86,7 +87,7 @@ class DifyConverter(BaseConverter): var_selector = ".".join(selector) mapping = { "sys.query": "sys.message" - } + } | self.node_output_map var_selector = mapping.get(var_selector, var_selector) return var_selector @@ -124,6 +125,8 @@ class DifyConverter(BaseConverter): "checkbox": VariableType.BOOLEAN, "file-list": VariableType.ARRAY_FILE, "select": VariableType.STRING, + "integer": VariableType.NUMBER, + "float": VariableType.NUMBER, } var_type = type_map.get(source_type, source_type) return var_type @@ -160,6 +163,8 @@ class DifyConverter(BaseConverter): "≥": ComparisonOperator.GE, "≤": ComparisonOperator.LE, "not empty": ComparisonOperator.NOT_EMPTY, + "start with": ComparisonOperator.START_WITH, + "end with": ComparisonOperator.END_WITH, } return operator_map.get(operator, operator) @@ -633,7 +638,7 @@ class DifyConverter(BaseConverter): prompt=node_data["instruction"] ).model_dump() - def convert_variable_aggregator(self, node: dict) -> dict: + def convert_variable_aggregator_node_config(self, node: dict) -> dict: node_data = node["data"] group_enable = node_data["advanced_settings"]["group_enabled"] group_variables = {} @@ -657,3 +662,13 @@ class DifyConverter(BaseConverter): group_variables=group_variables, group_type=group_type, ).model_dump() + + def convert_tool_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append(ExceptionDefineition( + node_id=node["id"], + node_name=node_data["title"], + type=ExceptionType.CONFIG, + detail=f"Please reconfigure the tool node.", + )) + return {} \ No newline at end of file diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index 48a0cbd6..2ecde092 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -43,7 +43,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL, "parameter-extractor": NodeType.PARAMETER_EXTRACTOR, "question-classifier": NodeType.QUESTION_CLASSIFIER, - "variable-aggregator": NodeType.VAR_AGGREGATOR + "variable-aggregator": NodeType.VAR_AGGREGATOR, + "tool": NodeType.TOOL } def __init__(self, config: dict[str, Any]): @@ -89,6 +90,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): return True def parse_workflow(self) -> WorkflowParserResult: + self._init_node_output_map() for node in self.origin_nodes: node = self._convert_node(node) if node: @@ -128,6 +130,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): errors=self.errors ) + def _init_node_output_map(self): + for node in self.origin_nodes: + if self.map_node_type(node["data"]["type"]) == NodeType.LLM: + self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output" + def _convert_cycle_node_position(self, node_id: str, position: dict): for node in self.origin_nodes: if node["id"] == node_id: @@ -214,6 +221,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): type=ExceptionType.EDGE, detail=f"convert edge error - {e}", )) + logger.debug(f"convert edge error - {e}", exc_info=True) return None def _convert_variable(self, variable) -> VariableDefinition | None: @@ -221,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): return VariableDefinition( name=variable["name"], default=variable["value"], - type=variable["value_type"], + type=self.variable_type_map(variable["value_type"]), ) except Exception as e: self.errors.append(ExceptionDefineition( diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 2b554a60..3c3137fe 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -175,7 +175,7 @@ class WorkflowExecutor: elapsed_time = (end_time - start_time).total_seconds() logger.info( - f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s") + f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) @@ -322,7 +322,7 @@ class WorkflowExecutor: ) logger.info( f"Workflow execution completed (streaming), " - f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}" + f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}" ) yield { diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index a01ffbe3..3e30c00e 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -196,7 +196,7 @@ class BaseNode(ABC): timeout=timeout ) - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 # Extract processed outputs using subclass-defined logic. extracted_output = self._extract_output(business_result) @@ -219,7 +219,7 @@ class BaseNode(ABC): } | self.trans_activate(state) except TimeoutError: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error( f"Node {self.node_id} execution timed out ({timeout} seconds)." ) @@ -230,7 +230,7 @@ class BaseNode(ABC): variable_pool, ) except Exception as e: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error( f"Node {self.node_id} execution failed: {e}", exc_info=True, @@ -307,10 +307,10 @@ class BaseNode(ABC): "done": done }) - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.info(f"Node {self.node_id} streaming execution finished, " - f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}") + f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) @@ -337,7 +337,7 @@ class BaseNode(ABC): yield state_update | self.trans_activate(state) except TimeoutError: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error(f"Node {self.node_id} execution timed out ({timeout}s)") error_output = self._wrap_error( f"Node execution timed out ({timeout}s)", @@ -347,7 +347,7 @@ class BaseNode(ABC): ) yield error_output except Exception as e: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True) error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool) yield error_output diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 26f0c41c..4bc5fc4c 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -12,9 +12,20 @@ class ExpressionEvaluator: # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} - - @staticmethod + + @classmethod + def normalize_template(cls, template: str) -> str: + pattern = re.compile( + r"\{\{\s*(\d+)\.(\w+)\s*}}" + ) + return pattern.sub( + r'{{ node["\1"].\2 }}', + template + ) + + @classmethod def evaluate( + cls, expression: str, conv_vars: dict[str, Any], node_outputs: dict[str, Any], @@ -37,6 +48,7 @@ class ExpressionEvaluator: """ # Remove Jinja2-style brackets if present expression = expression.strip() + expression = cls.normalize_template(expression) pattern = r"\{\{\s*(.*?)\s*\}\}" expression = re.sub(pattern, r"\1", expression).strip() diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 236e0840..424fdf20 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -5,6 +5,7 @@ """ import logging +import re from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined @@ -39,6 +40,16 @@ class TemplateRenderer: autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML ) + @staticmethod + def normalize_template(template: str) -> str: + pattern = re.compile( + r"\{\{\s*(\d+)\.(\w+)\s*}}" + ) + return pattern.sub( + r'{{ node["\1"].\2 }}', + template + ) + def render( self, template: str, @@ -95,7 +106,7 @@ class TemplateRenderer: context.update(conv_vars) context["nodes"] = node_outputs or {} # 旧语法兼容 - + template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index 9e15f227..e580833f 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -68,7 +68,7 @@ class WorkflowImportSave(BaseModel): """工作流导入请求""" temp_id: str name: str - description: str + description: str | None = Field(default=None) # ==================== 工作流配置 ==================== From e6aa0e0e108e9aecc35f434ff90f052f1e1fb0fd Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 17:51:12 +0800 Subject: [PATCH 014/164] [add]New semantic pruning effect display for streaming output --- api/app/services/pilot_run_service.py | 97 +++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 34b8867e..31e4d6dd 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -101,14 +101,101 @@ async def run_pilot_extraction( ) if progress_callback: - await progress_callback("text_preprocessing", "开始预处理文本...") + await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...") + # ========== 步骤 2.1: 语义剪枝 ========== + pruned_dialogs = [dialog] + deleted_messages = [] # 记录被删除的消息 + + if memory_config.pruning_enabled: + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( + SemanticPruner, + ) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + pruning_config_dict = { + "pruning_switch": memory_config.pruning_enabled, + "pruning_scene": memory_config.pruning_scene, + "pruning_threshold": memory_config.pruning_threshold, + "llm_model_id": str(memory_config.llm_model_id), + } + config = PruningConfig(**pruning_config_dict) + + logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}") + + # 记录剪枝前的消息(用于对比) + original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs] + original_msg_count = len(original_messages) + + # 执行剪枝 + pruner = SemanticPruner(config=config, llm_client=llm_client) + pruned_dialogs = await pruner.prune_dataset([dialog]) + + # 计算剪枝结果并找出被删除的消息 + if pruned_dialogs and pruned_dialogs[0].context: + remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs] + remaining_msg_count = len(remaining_messages) + deleted_msg_count = original_msg_count - remaining_msg_count + + # 找出被删除的消息(通过内容对比) + remaining_contents = {msg["content"] for msg in remaining_messages} + deleted_messages = [ + {"index": idx, "role": msg["role"], "content": msg["content"]} + for idx, msg in enumerate(original_messages) + if msg["content"] not in remaining_contents + ] + + pruning_result = { + "enabled": True, + "scene": config.pruning_scene, + "threshold": config.pruning_threshold, + "original_count": original_msg_count, + "remaining_count": remaining_msg_count, + "deleted_count": deleted_msg_count, + "deleted_messages": deleted_messages, + } + + logger.info( + f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> " + f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)" + ) + + if progress_callback: + await progress_callback("text_preprocessing_pruning", "语义剪枝完成", pruning_result) + else: + logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") + pruned_dialogs = [dialog] + + except Exception as e: + logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True) + pruned_dialogs = [dialog] + if progress_callback: + error_result = { + "enabled": True, + "error": str(e), + "fallback": "使用原始对话" + } + await progress_callback("text_preprocessing_pruning", "语义剪枝失败", error_result) + else: + logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") + if progress_callback: + pruning_result = { + "enabled": False, + "message": "语义剪枝已关闭" + } + await progress_callback("text_preprocessing_pruning", "语义剪枝已关闭", pruning_result) + + # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( - data=[dialog], + data=pruned_dialogs, chunker_strategy=memory_config.chunker_strategy, llm_client=llm_client, ) - logger.info(f"Processed dialogue text: {len(messages)} messages") + + remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0 + logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning") # 进度回调:输出每个分块的结果 if progress_callback: @@ -121,14 +208,14 @@ async def run_pilot_extraction( "dialog_id": dlg.id, "chunker_strategy": memory_config.chunker_strategy, } - await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + await progress_callback("text_preprocessing_chunking", f"分块 {i + 1} 处理完成", chunk_result) preprocessing_summary = { "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } - await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) + await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) From b79fe07052c8065ba3c04d6db627e5a6d8732c85 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Sat, 28 Feb 2026 18:01:00 +0800 Subject: [PATCH 015/164] feat(web): workflow import & export --- web/src/api/application.ts | 16 +- web/src/i18n/en.ts | 12 +- web/src/i18n/zh.ts | 13 +- .../components/ConfigHeader.tsx | 38 +- web/src/views/ApplicationConfig/types.ts | 5 +- .../components/UploadWorkflowModal.tsx | 348 +++++++++++------- web/src/views/ApplicationManagement/index.tsx | 10 +- web/src/views/ApplicationManagement/types.ts | 60 ++- .../AddChatVariable/ChatVariableModal.tsx | 93 +++-- .../views/Workflow/components/Chat/Chat.tsx | 10 +- .../Workflow/components/Chat/Runtime.tsx | 10 +- .../Workflow/components/Chat/chat.module.css | 5 +- .../components/Properties/CaseList/index.tsx | 10 +- .../Properties/ConditionList/index.tsx | 7 +- .../Properties/hooks/useVariableList.ts | 129 ++++++- web/src/views/Workflow/constant.ts | 2 +- .../views/Workflow/hooks/useWorkflowGraph.ts | 21 +- web/src/views/Workflow/index.tsx | 3 +- 18 files changed, 586 insertions(+), 206 deletions(-) diff --git a/web/src/api/application.ts b/web/src/api/application.ts index 244f3503..f019103e 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 13:59:45 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 13:59:45 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 16:34:15 */ import { request } from '@/utils/request' import type { ApplicationModalData } from '@/views/ApplicationManagement/types' @@ -120,3 +120,15 @@ export const copyApplication = (app_id: string, new_name: string) => { export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => { return request.get(`/apps/${app_id}/statistics`, data) } +// 导出工作流 +export const exportWorkflow = (app_id: string, fileName: string) => { + return request.downloadFile(`/apps/${app_id}/workflow/export`, fileName, undefined, undefined, 'GET') +} +// 工作流上传+兼容性分析 +export const importWorkflow = (formData: FormData) => { + return request.uploadFile(`/apps/workflow/import`, formData) +} +// 完成工作流导入 +export const completeImportWorkflow = (data: { temp_id: string; name?: string; description?: string }) => { + return request.post(`/apps/workflow/import/save`, data) +} diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 8dfb68db..9df6d018 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1340,12 +1340,20 @@ export const en = { dynamicMatchSkill: 'Dynamic Match Skill', executeTask: 'Execute Task', + importWorkflow: 'Import Workflow', + platform: 'Source Platform', upload: 'Upload & Parse', complex: 'Compatibility Analysis', - node: 'Node Mapping', - configCheck: 'Configuration Validation', sureInfo: 'Information Confirmation', completed: 'Import Completed', + workflowName: 'Workflow Name', + fileName: 'File Name', + fileSize: 'File Size', + importSuccess: 'Import Success', + importSuccessDesc: 'Workflow imported successfully, you can view and manage it in the application management', + gotoList: 'Return to Application List', + gotoDetail: 'View Details', + dify: 'Dify', }, userMemory: { userMemory: 'User Memory', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index feefc843..3fe37ea8 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -736,12 +736,21 @@ export const zh = { dynamicMatchSkill: '动态匹配技能', executeTask: '执行任务', + importWorkflow: '导入工作流', + platform: '来源平台', upload: '上传与解析', complex: '兼容性分析', - node: '节点映射', - configCheck: '配置校验', sureInfo: '信息确认', completed: '完成导入', + baseInfo: '基本信息', + workflowName: '工作流名称', + fileName: '文件名称', + fileSize: '文件大小', + importSuccess: '导入成功', + importSuccessDesc: '您的工作流已成功导入,可以在应用管理中查看和管理', + gotoList: '返回应用列表', + gotoDetail: '查看详情', + dify: 'Dify', }, table: { totalRecords: '共 {{total}} 条记录' diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index 374c87e8..42031d85 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -1,10 +1,10 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:27:52 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:27:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 16:48:52 */ -import { type FC, useRef } from 'react'; +import { type FC, useRef, useMemo } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; import { Layout, Tabs, Dropdown, Button, Flex } from 'antd'; import type { MenuProps } from 'antd'; @@ -21,6 +21,7 @@ import ApplicationModal from '@/views/ApplicationManagement/components/Applicati import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types' import { deleteApplication } from '@/api/application' import CopyModal from './CopyModal' +import { exportToYaml } from '@/utils/yamlExport'; const { Header } = Layout; @@ -80,20 +81,6 @@ const ConfigHeader: FC = ({ label: t(`application.${key}`), })) } - /** - * Format dropdown menu items - */ - const formatMenuItems = () => { - const items = ['edit', 'copy', 'export', 'delete'].map(key => ({ - key, - icon: , - label: t(`common.${key}`), - })) - return { - items, - onClick: handleClick - } - } /** * Handle menu item click */ @@ -106,6 +93,8 @@ const ConfigHeader: FC = ({ copyModalRef.current?.handleOpen() break; case 'export': + console.log('export', workflowRef?.current?.config) + exportToYaml(workflowRef?.current?.config, application?.name ?`${application?.name}.yml`: undefined) break; case 'delete': handleDelete() @@ -160,6 +149,19 @@ const ConfigHeader: FC = ({ const addvariable = () => { workflowRef?.current?.addVariable() } + /** + * Format dropdown menu items + */ + const formatMenuItems = useMemo(() => { + const items = (application?.type === 'workflow' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({ + key, + icon: , + label: t(`common.${key}`), + })) + return items + }, [t, handleClick, application]) + + console.log('formatMenuItems', formatMenuItems) return ( <>
@@ -170,7 +172,7 @@ const ConfigHeader: FC = ({
{application?.name}
diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index 2d09f739..36d40a40 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -2,13 +2,13 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:49 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-05 10:31:10 + * @Last Modified time: 2026-02-28 16:40:30 */ import type { KnowledgeConfig } from './components/Knowledge/types' import type { Variable } from './components/VariableList/types' import type { ToolOption } from './components/ToolList/types' import type { ChatItem } from '@/components/Chat/types' -import type { GraphRef } from '@/views/Workflow/types'; +import type { GraphRef, WorkflowConfig } from '@/views/Workflow/types'; import type { ApiKey } from '@/views/ApiKeyManagement/types' import type { SkillConfigForm } from './components/Skill/types' @@ -155,6 +155,7 @@ export interface WorkflowRef { graphRef: GraphRef; /** Add variable */ addVariable: () => void; + config: WorkflowConfig | null; } /** diff --git a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx index 2f2f56b2..68bca452 100644 --- a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx +++ b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx @@ -1,98 +1,203 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-28 14:08:14 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 16:20:40 + */ +/** + * UploadWorkflowModal Component + * + * This component provides a modal for uploading workflow files with a multi-step process: + * 1. Upload - Select platform and file + * 2. Complex - Show warnings and errors if any + * 3. SureInfo - Confirm and edit workflow information + * 4. Completed - Show success message and options + */ import { forwardRef, useImperativeHandle, useState, useMemo } from 'react'; -import { Form, Select, Steps, Flex, Alert, Row, Col, Statistic, Input, Button } from 'antd'; +import { Form, Select, Steps, Flex, Alert, Input, Button, Result } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { UploadWorkflowModalData, UploadWorkflowModalRef } from '../types' +import type { UploadWorkflowModalData, UploadData, UploadWorkflowModalRef } from '../types' import RbModal from '@/components/RbModal' import UploadFiles from '@/components/Upload/UploadFiles' -import { fileUploadUrl } from '@/api/fileStorage' -import RbCard from '@/components/RbCard/Card' +import { importWorkflow, completeImportWorkflow } from '@/api/application' +/** + * Props for UploadWorkflowModal component + */ interface UploadWorkflowModalProps { + /** Function to refresh the parent component after workflow import */ refresh: () => void; } + +/** + * Steps definition for the upload process + */ const steps = [ - 'upload', - 'complex', - 'node', - 'configCheck', - 'sureInfo', - 'completed' + 'upload', // Step 1: File upload + 'complex', // Step 2: Error/warning display + 'sureInfo', // Step 3: Information confirmation + 'completed' // Step 4: Success message ] + +/** + * UploadWorkflowModal component + * + * @param {UploadWorkflowModalProps} props - Component props + * @param {React.Ref} ref - Ref for imperative methods + */ const UploadWorkflowModal = forwardRef(({ refresh }, ref) => { const { t } = useTranslation(); - const [visible, setVisible] = useState(false); - const [form] = Form.useForm(); - const [loading, setLoading] = useState(false) - const [current, setCurrent] = useState(5); + + // State management + const [visible, setVisible] = useState(false); // Modal visibility + const [form] = Form.useForm(); // Form instance + const [loading, setLoading] = useState(false); // Loading state + const [current, setCurrent] = useState(0); // Current step + const [data, setData] = useState(null); // Upload response data + const [firstFormData, setFirstFormData] = useState(null); // First step form data + const [appId, setAppId] = useState(null); // Imported application ID - // 封装取消方法,添加关闭弹窗逻辑 + /** + * Handle modal close + * Resets all states and form fields + */ const handleClose = () => { setVisible(false); form.resetFields(); - setLoading(false) + setData(null); + setCurrent(0); + setFirstFormData(null); + setAppId(null); + setLoading(false); }; + /** + * Handle modal open + * Resets form fields and shows modal + */ const handleOpen = () => { form.resetFields(); setVisible(true); }; - // 封装保存方法,添加提交逻辑 + + /** + * Handle save/submit action + * Processes different logic based on current step + */ const handleSave = () => { + const values = form.getFieldsValue(); + switch(current) { - case 0: - setCurrent(1) + case 0: // Step 1: Upload file + const formData = new FormData(); + setFirstFormData(values); + formData.append('platform', values.platform); + formData.append('file', values.file[0]); + + // Call import workflow API + importWorkflow(formData) + .then(res => { + const response = res as UploadData; + const { errors, warnings } = response; + setData(response); + + // Navigate to error/warning step if any, otherwise go to confirmation + if (errors.length || warnings.length) { + setCurrent(1); + } else { + setCurrent(2); + // Pre-fill form with file information + form.setFieldsValue({ + name: values.file[0].name.split('.')[0], + platform: values.platform, + fileName: values.file[0].name, + fileSize: values.file[0].size, + }); + } + }); break; - case 1: - setCurrent(2) + case 1: // Step 2: Error/warning display + if (firstFormData) { + const { file, platform } = firstFormData; + // Pre-fill form with file information + form.setFieldsValue({ + name: file[0].name.split('.')[0], + platform: platform, + fileName: file[0].name, + fileSize: file[0].size, + }); + } + setCurrent(2); break; - case 2: - setCurrent(3) - break; - case 3: - setCurrent(4) - break; - case 4: - setCurrent(5) - break; - case 5: + case 2: // Step 3: Confirm information + if (data) { + // Complete import workflow + completeImportWorkflow({ + temp_id: data.temp_id, + name: values.name, + description: values.description, + }) + .then((res) => { + const response = res as { id: string }; + setCurrent(3); + setAppId(response.id); + }); + } break; default: - setCurrent(prev => prev + 1) + setCurrent(prev => prev + 1); break; } - // form - // .validateFields() - // .then(() => { - // }) - // .catch((err) => { - // console.log('err', err) - // }); - } + }; - // 暴露给父组件的方法 + // Expose methods to parent component via ref useImperativeHandle(ref, () => ({ handleOpen, handleClose })); + /** + * Handle navigation to previous step + * Adjusts step based on whether there were errors/warnings + */ const handleLastStep = () => { - setCurrent(prev => prev - 1) - } + let newStep = current - 1; + // If no errors or warnings, skip the error/warning step + if (!data?.warnings?.length && !data?.errors?.length) { + newStep = current - 2; + } + + // Reset form if not going back to error/warning step + if (newStep !== 1) { + form.resetFields(); + } + setCurrent(newStep); + }; + + /** + * Handle navigation after successful import + * @param {string} type - Navigation type ('detail' or 'list') + */ const handleJump = (type: string) => { switch(type) { case 'detail': - break; - default: + // Open application detail page in new tab + window.open(`/#/application/config/${appId}`, '_blank'); break; } - } + refresh(); + handleClose(); + }; + /** + * Generate modal footer based on current step + */ const getFooter = useMemo(() => { switch(current) { - case 0: + case 0: // Step 1: Upload return [ - ] - case 5: - return [ - , - - ] - default: + ]; + case 3: // Step 4: Completed + return null; + default: // Steps 1-2 return [ , , - ] + ]; } - }, [current]) + }, [current]); return ( + {/* Steps indicator */}
({ title: t(`application.${key}`) }))} />
+ + {/* Step 1: File upload */} {current === 0 &&
- + - - - - - } - {current === 3 && - - - - } - {current === 4 && - +
{t('application.baseInfo')}
- - source + + - fileName + - fileSize + - + - -
{t('application.importStatistic')}
- - {['complex', 'nodes', 'task'].map(key => ( - - - - ))} - } - {current === 5 && - -
导入成功
-
您的工作流已成功导入,可以在应用管理中查看和管理
-
+ + {/* Step 4: Success message */} + {current === 3 && + handleJump('list')}> + {t('application.gotoList')} + , + + ]} + /> }
); diff --git a/web/src/views/ApplicationManagement/index.tsx b/web/src/views/ApplicationManagement/index.tsx index 74dcef05..ca9888de 100644 --- a/web/src/views/ApplicationManagement/index.tsx +++ b/web/src/views/ApplicationManagement/index.tsx @@ -83,9 +83,9 @@ const ApplicationManagement: React.FC = () => { setQuery(prev => ({...prev, type: value})) } - // const handleImport = () => { - // uploadWorkflowModalRef.current?.handleOpen() - // } + const handleImport = () => { + uploadWorkflowModalRef.current?.handleOpen() + } return ( <> @@ -111,9 +111,9 @@ const ApplicationManagement: React.FC = () => { - {/* */} + diff --git a/web/src/views/ApplicationManagement/types.ts b/web/src/views/ApplicationManagement/types.ts index ccc4f114..696b828a 100644 --- a/web/src/views/ApplicationManagement/types.ts +++ b/web/src/views/ApplicationManagement/types.ts @@ -2,12 +2,12 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:34:15 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 11:08:37 + * @Last Modified time: 2026-02-28 16:16:03 */ /** * Type definitions for Application Management */ - +import type { WorkflowConfig } from '@/views/Workflow/types'; /** * Search query parameters */ @@ -174,9 +174,63 @@ export interface ApiExtensionModalRef { handleOpen: () => void; } - +/** + * Upload workflow modal form data + */ export interface UploadWorkflowModalData { + /** Platform type (e.g., 'dify') */ + platform: string; + /** Array of uploaded files */ + file: any[]; + /** Optional workflow name */ + name?: string; + /** Optional original file name */ + fileName?: string; + /** Optional file size in bytes */ + fileSize?: number; + /** Optional workflow description */ + description?: string; } + +/** + * Complex item for errors and warnings + */ +interface ComplexItem { + /** Error/warning type */ + type: string; + /** Detailed error/warning message */ + detail: string; + /** Node identifier where the error/warning occurred */ + node_id: string; + /** Node name where the error/warning occurred */ + node_name: string; + /** Optional scope of the error/warning */ + scope: string | null; + /** Optional name associated with the error/warning */ + name: string | null; +} + +/** + * Upload data response + * @extends WorkflowConfig + */ +export interface UploadData extends WorkflowConfig { + /** Whether the upload was successful */ + success: boolean; + /** Temporary identifier for the uploaded workflow */ + temp_id: string; + /** Optional workflow identifier if already exists */ + workflow_id?: string; + /** Array of error items */ + errors: ComplexItem[]; + /** Array of warning items */ + warnings: ComplexItem[]; +} + +/** + * Upload workflow modal ref interface + */ export interface UploadWorkflowModalRef { + /** Open the upload workflow modal */ handleOpen: () => void; } \ No newline at end of file diff --git a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx index 52394ea1..15933d4a 100644 --- a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx @@ -1,3 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2025-12-30 13:59:36 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 16:19:26 + */ +/** + * ChatVariableModal Component + * + * This component provides a modal for adding or editing chat variables in workflows. + * It supports various variable types and provides appropriate input fields based on the selected type. + */ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input, Select, InputNumber } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -8,54 +20,86 @@ import RbModal from '@/components/RbModal' const FormItem = Form.Item; +/** + * Props for ChatVariableModal component + */ interface ChatVariableModalProps { + /** + * Callback function to refresh variable list + * @param {ChatVariable} value - The variable data + * @param {number} [editIndex] - Optional index for editing existing variable + */ refresh: (value: ChatVariable, editIndex?: number) => void; } +/** + * Supported variable types + */ const types = [ - 'string', - 'number', - 'boolean', + 'string', // String type + 'number', // Number type + 'boolean', // Boolean type + 'object', // Object type + 'array[string]', // Array of strings + 'array[number]', // Array of numbers + 'array[boolean]', // Array of booleans + 'array[object]', // Array of objects ] +/** + * ChatVariableModal component + */ const ChatVariableModal = forwardRef(({ refresh }, ref) => { const { t } = useTranslation(); - const [visible, setVisible] = useState(false); - const [form] = Form.useForm(); - const [loading, setLoading] = useState(false) - const [editIndex, setEditIndex] = useState(undefined) - const type = Form.useWatch('type', form); - // 封装取消方法,添加关闭弹窗逻辑 + // State management + const [visible, setVisible] = useState(false); // Modal visibility + const [form] = Form.useForm(); // Form instance + const [loading, setLoading] = useState(false); // Loading state + const [editIndex, setEditIndex] = useState(undefined); // Index of variable being edited + const type = Form.useWatch('type', form); // Current selected type + + /** + * Handle modal close + */ const handleClose = () => { setVisible(false); form.resetFields(); - setLoading(false) - setEditIndex(undefined) + setLoading(false); + setEditIndex(undefined); }; + /** + * Handle modal open + */ const handleOpen = (variable?: ChatVariable, index?: number) => { setVisible(true); if (variable) { - const { default: _, ...rest } = variable - form.setFieldsValue({ ...rest }) - setEditIndex(index) + // Exclude 'default' property and set form values + const { default: _, ...rest } = variable; + form.setFieldsValue({ ...rest }); + setEditIndex(index); } else { + // Reset form for new variable form.resetFields(); - setEditIndex(undefined) + setEditIndex(undefined); } }; - // 封装保存方法,添加提交逻辑 + + /** + * Handle save/submit action + */ const handleSave = () => { form.validateFields().then((values) => { - refresh({ ...values, default: values.defaultValue }, editIndex) - handleClose() - }) - } + // Create variable with 'default' property mapped from 'defaultValue' + refresh({ ...values, default: values.defaultValue }, editIndex); + handleClose(); + }); + }; - // 暴露给父组件的方法 + // Expose handleOpen method to parent component via ref useImperativeHandle(ref, () => ({ handleOpen })); @@ -74,6 +118,7 @@ const ChatVariableModal = forwardRef + {/* Variable name field */} + + {/* Variable type field */} + + {/* Default value field - dynamic based on type */} } + + {/* Variable description field */} (({ appId setConversationId(null) setMessage(undefined) setFileList([]) + setLoading(false) + setStreamLoading(false) } /** * Opens the variable configuration modal @@ -179,6 +181,7 @@ const Chat = forwardRef(({ appId cycle_idx: number; node_id: string; node_name?: string; + node_type?: string; input?: any; output?: any; elapsed_time?: string; @@ -188,7 +191,7 @@ const Chat = forwardRef(({ appId }; const node = graphRef.current?.getNodes().find(n => n.id === node_id); - const { name, icon } = node?.getData() || {} + const { name, icon, type } = node?.getData() || {} switch(item.event) { // Append streaming text chunks to assistant message @@ -218,6 +221,7 @@ const Chat = forwardRef(({ appId ...newSubContent[filterIndex], node_id: node_id, node_name: name, + node_type: type, icon, content: {}, } @@ -226,6 +230,7 @@ const Chat = forwardRef(({ appId id: node_id, node_id: node_id, node_name: name, + node_type: type, icon, content: {}, }) @@ -282,6 +287,7 @@ const Chat = forwardRef(({ appId cycle_idx, node_id, node_name: name, + node_type: type, icon, content: { cycle_idx, diff --git a/web/src/views/Workflow/components/Chat/Runtime.tsx b/web/src/views/Workflow/components/Chat/Runtime.tsx index 0f18f4da..e3608e10 100644 --- a/web/src/views/Workflow/components/Chat/Runtime.tsx +++ b/web/src/views/Workflow/components/Chat/Runtime.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-24 17:57:08 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-24 17:57:08 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 16:48:09 */ /* * Runtime Component @@ -105,7 +105,7 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({ if (Array.isArray(list)) { return {list?.map(vo => { - const isLoop = vo.node_id.startsWith('loop'); + const isLoop = vo.node_type === 'loop'; // Render cycle variables for loop nodes without node_name if (typeof vo.cycle_idx === 'number' && isLoop && !vo.node_name) { return
@@ -165,7 +165,7 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({ } {/* Display navigation to nested cycles if subContent exists */} {vo.subContent?.length > 0 && ( - handleViewDetail(vo, vo.node_id.startsWith('loop'))}> + handleViewDetail(vo, vo.node_type === 'loop')}> {Math.max(...vo.subContent.map((itemVo: any) => itemVo.cycle_idx + 1))} {t(`workflow.${isLoop ? 'loopNum' : 'iterationNum'}`)} @@ -217,7 +217,7 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({ children: ( detail ? ( -
+
diff --git a/web/src/views/Workflow/components/Chat/chat.module.css b/web/src/views/Workflow/components/Chat/chat.module.css index 99fe11f7..c005ef2a 100644 --- a/web/src/views/Workflow/components/Chat/chat.module.css +++ b/web/src/views/Workflow/components/Chat/chat.module.css @@ -28,9 +28,6 @@ background-color: transparent; border-top: none; } -:global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) { - padding-top: 0; -} .collapse-item :global(.ant-collapse) { /* background-color: #F0F3F8; */ background-color: #FBFDFF; @@ -41,5 +38,5 @@ border-radius: 0 0 6px 6px; } .collapse-item :global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) { - padding: 0 4px 4px 4px; + padding: 4px; } \ No newline at end of file diff --git a/web/src/views/Workflow/components/Properties/CaseList/index.tsx b/web/src/views/Workflow/components/Properties/CaseList/index.tsx index 34708513..70c4c43f 100644 --- a/web/src/views/Workflow/components/Properties/CaseList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CaseList/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-09 18:24:53 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 18:24:53 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-28 17:49:28 */ import { type FC } from 'react' import clsx from 'clsx' @@ -292,7 +292,7 @@ const CaseList: FC = ({ const leftFieldOption = options.find(option => `{{${option.value}}}` === leftFieldValue); const leftFieldType = leftFieldOption?.dataType; const operatorList = operatorsObj[leftFieldType || 'default'] || operatorsObj.default || []; - const inputType = leftFieldType === 'number' ? currentExpression.input_type : undefined; + const inputType = leftFieldType === 'number' ? currentExpression.input_type?.toLocaleLowerCase() : undefined; return (
@@ -330,7 +330,7 @@ const CaseList: FC = ({ = ({ - {inputType === 'Variable' + {inputType === 'variable' ? } keys - Set of existing keys to check for duplicates + * @param {string} key - Unique key for the variable + * @param {string} label - Human-readable label for the variable + * @param {string} dataType - Data type of the variable + * @param {string} value - Variable value/expression + * @param {any} nodeData - Node data associated with the variable + * @param {Partial} [extra] - Additional suggestion properties + */ const addVariable = ( list: Suggestion[], keys: Set, @@ -39,6 +72,14 @@ const addVariable = ( } }; +/** + * Process node variables based on node type + * + * @param {any} nodeData - Node data object + * @param {string} dataNodeId - Node ID + * @param {Suggestion[]} variableList - List to add variables to + * @param {Set} addedKeys - Set of already added keys + */ const processNodeVariables = ( nodeData: any, dataNodeId: string, @@ -47,29 +88,35 @@ const processNodeVariables = ( ) => { const { type, config } = nodeData; + // Add node-specific variables if (type in NODE_VARIABLES) { NODE_VARIABLES[type as keyof typeof NODE_VARIABLES].forEach(({ label, dataType, field }) => { addVariable(variableList, addedKeys, `${dataNodeId}_${label}`, label, dataType, `${dataNodeId}.${field}`, nodeData); }); } + // Process special node types switch (type) { case 'start': + // Add start node variables [...(config?.variables?.defaultValue ?? []), ...(config?.variables?.value ?? [])].forEach((v: any) => { if (v?.name) addVariable(variableList, addedKeys, `${dataNodeId}_${v.name}`, v.name, v.type, `${dataNodeId}.${v.name}`, nodeData); }); + // Add system variables config?.variables?.sys?.forEach((v: any) => { if (v?.name) addVariable(variableList, addedKeys, `${dataNodeId}_sys_${v.name}`, `sys.${v.name}`, v.type, `sys.${v.name}`, nodeData); }); break; case 'parameter-extractor': + // Add extracted parameters (config?.params?.defaultValue || []).forEach((p: any) => { if (p?.name) addVariable(variableList, addedKeys, `${dataNodeId}_${p.name}`, p.name, p.type || 'string', `${dataNodeId}.${p.name}`, nodeData); }); break; case 'var-aggregator': + // Add aggregated variables if (config.group.defaultValue) { (config.group_variables.defaultValue || []).forEach((gv: any) => { if (gv?.key) { @@ -93,6 +140,7 @@ const processNodeVariables = ( break; case 'iteration': + // Add iteration output variable let dt = 'string'; if (nodeData.output) { const sv = variableList.find(v => v.value === nodeData.output); @@ -102,11 +150,14 @@ const processNodeVariables = ( break; case 'loop': + // Add loop cycle variables (config.cycle_vars.defaultValue || []).forEach((cv: any) => { if (cv.name?.trim()) addVariable(variableList, addedKeys, `${dataNodeId}_cycle_${cv.name}`, cv.name, cv.type || 'string', `${dataNodeId}.${cv.name}`, nodeData); }); break; + case 'code': + // Add code node output variables (config.output_variables.defaultValue || []).forEach((cv: any) => { if (cv.name?.trim()) addVariable(variableList, addedKeys, `${dataNodeId}_cycle_${cv.name}`, cv.name, cv.type || 'string', `${dataNodeId}.${cv.name}`, nodeData); }); @@ -114,6 +165,9 @@ const processNodeVariables = ( } }; +/** + * Node types that have output variables + */ const hasOutputNodeTypes = [ 'llm', 'knowledge-retrieval', @@ -123,7 +177,15 @@ const hasOutputNodeTypes = [ 'http-request', 'tool', 'jinja-render' -] +]; + +/** + * Get variables for the current node + * + * @param {any} nodeData - Node data object + * @param {any} values - Additional values to merge with node config + * @returns {Suggestion[]} List of node variables + */ export const getCurrentNodeVariables = (nodeData: any, values: any): Suggestion[] => { if (!nodeData || !hasOutputNodeTypes.includes(nodeData.type)) return []; const list: Suggestion[] = []; @@ -137,9 +199,18 @@ export const getCurrentNodeVariables = (nodeData: any, values: any): Suggestion[ ...values } }, dataNodeId, list, keys); + + // Special case: var-aggregator without group enabled returns no variables return nodeData.type === 'var-aggregator' && !nodeData.config.group.defaultValue ? [] : list; }; +/** + * Get variables from child nodes in a loop/iteration + * + * @param {Node} selectedNode - Selected node + * @param {React.MutableRefObject} graphRef - Graph reference + * @returns {Suggestion[]} List of child node variables + */ export const getChildNodeVariables = ( selectedNode: Node, graphRef: React.MutableRefObject @@ -152,8 +223,15 @@ export const getChildNodeVariables = ( const edges = graph.getEdges(); const keys = new Set(); + // Find child nodes in the same cycle const childNodes = nodes.filter(node => node.getData()?.cycle === selectedNode.id); + /** + * Get all connected nodes recursively + * @param {string} nodeId - Node ID to start from + * @param {Set} visited - Set of visited node IDs + * @returns {string[]} List of connected node IDs + */ const getConnectedNodes = (nodeId: string, visited = new Set()): string[] => { if (visited.has(nodeId)) return []; visited.add(nodeId); @@ -161,12 +239,14 @@ export const getChildNodeVariables = ( return [...prev, ...prev.flatMap(id => getConnectedNodes(id, visited))]; }; + // Collect all relevant node IDs const relevantIds = new Set(); childNodes.forEach(child => { relevantIds.add(child.id); getConnectedNodes(child.id).forEach(id => relevantIds.add(id)); }); + // Process each relevant node relevantIds.forEach(id => { const node = nodes.find(n => n.id === id); if (!node) return; @@ -175,6 +255,7 @@ export const getChildNodeVariables = ( const nodeId = nodeData.id; const { type } = nodeData; + // Add node-specific variables if (type in NODE_VARIABLES) { NODE_VARIABLES[type as keyof typeof NODE_VARIABLES].forEach(({ label, dataType, field }) => { const varKey = `${nodeId}_${label}`; @@ -192,6 +273,7 @@ export const getChildNodeVariables = ( }); } + // Add parameter-extractor variables if (type === 'parameter-extractor') { (nodeData.config?.params?.defaultValue || []).forEach((p: any) => { if (p?.name && !keys.has(`${nodeId}_${p.name}`)) { @@ -207,11 +289,36 @@ export const getChildNodeVariables = ( } }); } + + // Add code node variables + if (type === 'code') { + (nodeData.config?.output_variables?.defaultValue || []).forEach((p: any) => { + if (p?.name && !keys.has(`${nodeId}_${p.name}`)) { + keys.add(`${nodeId}_${p.name}`); + list.push({ + key: `${nodeId}_${p.name}`, + label: p.name, + type: 'variable', + dataType: p.type || 'string', + value: `${nodeId}.${p.name}`, + nodeData, + }); + } + }); + } }); return list; }; +/** + * Hook for managing workflow variable list + * + * @param {Node | null | undefined} selectedNode - Currently selected node + * @param {React.MutableRefObject} graphRef - Graph reference + * @param {ChatVariable[]} chatVariables - List of chat variables + * @returns {Suggestion[]} List of available variables + */ export const useVariableList = ( selectedNode: Node | null | undefined, graphRef: React.MutableRefObject, @@ -228,6 +335,12 @@ export const useVariableList = ( const nodes = graph.getNodes(); const keys = new Set(); + /** + * Get all previous connected nodes recursively + * @param {string} nodeId - Node ID to start from + * @param {Set} visited - Set of visited node IDs + * @returns {string[]} List of previous node IDs + */ const getPreviousNodes = (nodeId: string, visited = new Set()): string[] => { if (visited.has(nodeId)) return []; visited.add(nodeId); @@ -235,6 +348,11 @@ export const useVariableList = ( return [...prev, ...prev.flatMap(id => getPreviousNodes(id, visited))]; }; + /** + * Get parent loop/iteration node + * @param {string} nodeId - Node ID to check + * @returns {Node | null} Parent loop/iteration node or null + */ const getParentLoop = (nodeId: string): Node | null => { const node = nodes.find(n => n.id === nodeId); const cycle = node?.getData()?.cycle; @@ -245,17 +363,21 @@ export const useVariableList = ( return null; }; + // Collect relevant node IDs const childIds = nodes.filter(n => n.getData()?.cycle === selectedNode.id).map(n => n.id); const parentLoop = getParentLoop(selectedNode.id); const relevantIds = [...getPreviousNodes(selectedNode.id), ...childIds, ...(parentLoop ? getPreviousNodes(parentLoop.id) : [])]; + // Add chat variables chatVariables?.forEach(v => addVariable(list, keys, `CONVERSATION_${v.name}`, v.name, v.type, `conv.${v.name}`, { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' }, { group: 'CONVERSATION' })); + // Process each relevant node relevantIds.forEach(id => { const node = nodes.find(n => n.id === id); if (node) processNodeVariables(node.getData(), node.getData().id, list, keys); }); + // Add parent loop variables if (parentLoop) { const pd = parentLoop.getData(); const pid = pd.id; @@ -270,7 +392,9 @@ export const useVariableList = ( } else if (pd.type === 'iteration' && !pd.config.input.defaultValue) { let itemType = 'object'; const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); - if (iv?.dataType.startsWith('array[')) {itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1');} + if (iv?.dataType.startsWith('array[')) { + itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1'); + } addVariable(list, keys, `${pid}_item`, 'item', 'string', `${pid}.item`, pd); addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); } @@ -279,6 +403,7 @@ export const useVariableList = ( return list; }, [selectedNode, graphRef, trigger, chatVariables]); + // Refresh variable list when graph changes useEffect(() => { if (!graphRef?.current) return; const graph = graphRef.current; diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index 5ae3e5b0..9c65174c 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -313,7 +313,7 @@ export const nodeLibrary: NodeLibrary[] = [ config: { input: { type: 'variableList', - filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor', 'code'], + filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor', 'code', 'CONVERSATION'], filterVariableNames: ['message'] }, parallel: { diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 35cb5aa3..a50bb416 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 18:37:01 + * @Last Modified time: 2026-02-28 17:59:34 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -135,7 +135,24 @@ export const useWorkflowGraph = ({ if (nodeLibraryConfig?.config) { Object.keys(nodeLibraryConfig.config).forEach(key => { - if (type === 'memory-write' && key === 'message' && nodeLibraryConfig.config) { + if (type === 'loop' && key === 'condition' && nodeLibraryConfig.config) { + const { condition } = config; + console.log('condition', condition) + nodeLibraryConfig.config[key].defaultValue = condition ? { + ...condition, + expressions: (condition as any).expressions.map((expr: any) => { + return expr.input_type ? { ...expr, input_type: expr.input_type.toLocaleLowerCase() } : expr + }) + } : {} + } else if (type === 'if-else' && key === 'cases' && nodeLibraryConfig.config) { + const { cases } = config; + nodeLibraryConfig.config[key].defaultValue = cases && Array.isArray(cases) ? cases.map(item => ({ + ...item, + expressions: item.expressions.map((expr: any) => { + return expr.input_type ? { ...expr, input_type: expr.input_type.toLocaleLowerCase() } : expr + }), + })) : [] + } else if (type === 'memory-write' && key === 'message' && nodeLibraryConfig.config) { nodeLibraryConfig.config['messages'].defaultValue = [{ role: 'USER', content: config[key] }] delete nodeLibraryConfig.config[key] } else if (key === 'memory' && nodeLibraryConfig.config && nodeLibraryConfig.config[key]) { diff --git a/web/src/views/Workflow/index.tsx b/web/src/views/Workflow/index.tsx index 506fd3c4..31e3d4df 100644 --- a/web/src/views/Workflow/index.tsx +++ b/web/src/views/Workflow/index.tsx @@ -58,7 +58,8 @@ const Workflow = forwardRef((_props, ref) => { handleSave, handleRun, graphRef, - addVariable + addVariable, + config })) return (
From 035464c0ac65bbc7ff0e0f6208701bcf08d25b36 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 18:19:44 +0800 Subject: [PATCH 016/164] [fix]Fix the display issue of semantic chunking for streaming output --- api/app/services/pilot_run_service.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 31e4d6dd..4cfa158d 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -200,18 +200,19 @@ async def run_pilot_extraction( # 进度回调:输出每个分块的结果 if progress_callback: for dlg in chunked_dialogs: - for i, chunk in enumerate(dlg.chunks): - chunk_result = { - "chunk_index": i + 1, - "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, - "full_length": len(chunk.content), - "dialog_id": dlg.id, - "chunker_strategy": memory_config.chunker_strategy, - } - await progress_callback("text_preprocessing_chunking", f"分块 {i + 1} 处理完成", chunk_result) + if hasattr(dlg, 'chunks') and dlg.chunks: + for i, chunk in enumerate(dlg.chunks): + chunk_result = { + "chunk_index": i + 1, + "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, + "full_length": len(chunk.content), + "dialog_id": dlg.id, + "chunker_strategy": memory_config.chunker_strategy, + } + await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) preprocessing_summary = { - "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs), + "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } From 6718553bf4f4ceb105734a98a8b69645e00a34ff Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:47:08 +0800 Subject: [PATCH 017/164] Fix/develop memory rag (#419) * fix_rag/fast summary * fix_rag/fast summary --- .../langgraph_graph/nodes/summary_nodes.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0144c0e9..cf832add 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -17,6 +17,8 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.rag.nlp.search import knowledge_retrieval + from app.db import get_db template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') @@ -32,6 +34,41 @@ class SummaryNodeService(LLMServiceMixin): # 创建全局服务实例 summary_service = SummaryNodeService() +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config +async def rag_knowledge(state,question): + kb_config = await rag_config(state) + end_user_id = state.get('end_user_id', '') + user_rag_memory_id=state.get("user_rag_memory_id",'') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception : + retrieval_knowledge=[] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge,clean_content,cleaned_query,raw_results async def summary_history(state: ReadState) -> ReadState: end_user_id = state.get("end_user_id", '') @@ -71,7 +108,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) # 验证结构化响应 if structured is None: - logger.warning(f"LLM返回None,使用默认回答") + logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" # 根据操作类型提取答案 @@ -82,7 +119,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if hasattr(structured, 'data') and structured.data: aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" else: - logger.warning(f"结构化响应缺少data字段") + logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" # 验证答案不为空 @@ -186,12 +223,13 @@ async def Input_Summary(state: ReadState) -> ReadState: } try: - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + if storage_type!="rag": + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + else: + retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) retrieve_info, question, raw_results = "", data, [] - - try: # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # 'input_summary',RetrieveSummaryResponse) @@ -290,7 +328,6 @@ async def Summary(state: ReadState)-> ReadState: summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] return {"summary":summary} - async def Summary_fails(state: ReadState)-> ReadState: storage_type=state.get("storage_type", '') user_rag_memory_id=state.get("user_rag_memory_id", '') From 4c592bf7e3f0132c72a4e94197222ed67360f773 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:58:33 +0800 Subject: [PATCH 018/164] Feature/default ontology (#424) * [add]Create a workspace and initialize the default ontology engineering scenario * [add]The language parameters for creating the workspace determine the default language for switching in the ontology project. * [changes]Standardized return format * [add]The default ontology is associated with the default configuration. * [add]Create a workspace and initialize the default ontology engineering scenario * [add]The language parameters for creating the workspace determine the default language for switching in the ontology project. * [changes]Standardized return format * [add]The default ontology is associated with the default configuration. --- api/app/config/__init__.py | 1 + api/app/config/default_ontology_config.py | 239 +++++++++++++++++ .../config/default_ontology_initializer.py | 249 ++++++++++++++++++ .../controllers/memory_agent_controller.py | 4 +- api/app/controllers/ontology_controller.py | 31 ++- .../controllers/ontology_secondary_routes.py | 32 ++- api/app/controllers/workspace_controller.py | 21 +- api/app/models/ontology_class.py | 5 +- api/app/models/ontology_scene.py | 5 +- api/app/services/memory_agent_service.py | 12 +- api/app/services/workspace_service.py | 117 +++++++- redbear-mem-benchmark | 2 +- 12 files changed, 696 insertions(+), 22 deletions(-) create mode 100644 api/app/config/__init__.py create mode 100644 api/app/config/default_ontology_config.py create mode 100644 api/app/config/default_ontology_initializer.py diff --git a/api/app/config/__init__.py b/api/app/config/__init__.py new file mode 100644 index 00000000..df675a16 --- /dev/null +++ b/api/app/config/__init__.py @@ -0,0 +1 @@ +"""Configuration module for application settings.""" diff --git a/api/app/config/default_ontology_config.py b/api/app/config/default_ontology_config.py new file mode 100644 index 00000000..157aa73e --- /dev/null +++ b/api/app/config/default_ontology_config.py @@ -0,0 +1,239 @@ +"""默认本体场景配置 + +本模块定义系统预设的本体场景和实体类型配置。 +这些配置用于在工作空间创建时自动初始化默认场景。 +支持中英文双语配置,根据用户语言偏好创建对应语言的场景。 +""" + +# 在线教育场景配置 +ONLINE_EDUCATION_SCENE = { + "name_chinese": "在线教育", + "name_english": "Online Education", + "description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型", + "description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses", + "types": [ + { + "name_chinese": "学生", + "name_english": "Student", + "description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性", + "description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class" + }, + { + "name_chinese": "教师", + "name_english": "Teacher", + "description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性", + "description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title" + }, + { + "name_chinese": "课程", + "name_english": "Course", + "description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性", + "description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours" + }, + { + "name_chinese": "作业", + "name_english": "Assignment", + "description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性", + "description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status" + }, + { + "name_chinese": "成绩", + "name_english": "Grade", + "description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性", + "description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course" + }, + { + "name_chinese": "考试", + "name_english": "Exam", + "description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性", + "description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject" + }, + { + "name_chinese": "教室", + "name_english": "Classroom", + "description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性", + "description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment" + }, + { + "name_chinese": "学科", + "name_english": "Subject", + "description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性", + "description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department" + }, + { + "name_chinese": "教材", + "name_english": "Textbook", + "description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性", + "description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN" + }, + { + "name_chinese": "班级", + "name_english": "Class", + "description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性", + "description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher" + }, + { + "name_chinese": "学期", + "name_english": "Semester", + "description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性", + "description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time" + }, + { + "name_chinese": "课时", + "name_english": "Class Hour", + "description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性", + "description_english": "Time units of courses, including attributes such as class time, location, teacher, and course" + }, + { + "name_chinese": "教学计划", + "name_english": "Teaching Plan", + "description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性", + "description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan" + } + ] +} + +# 情感陪伴场景配置 +EMOTIONAL_COMPANION_SCENE = { + "name_chinese": "情感陪伴", + "name_english": "Emotional Companion", + "description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型", + "description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities", + "types": [ + { + "name_chinese": "用户", + "name_english": "User", + "description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性", + "description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences" + }, + { + "name_chinese": "情绪", + "name_english": "Emotion", + "description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性", + "description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration" + }, + { + "name_chinese": "活动", + "name_english": "Activity", + "description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性", + "description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location" + }, + { + "name_chinese": "对话", + "name_english": "Conversation", + "description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性", + "description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content" + }, + { + "name_chinese": "兴趣爱好", + "name_english": "Hobby", + "description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性", + "description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities" + }, + { + "name_chinese": "日常事件", + "name_english": "Daily Event", + "description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性", + "description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people" + }, + { + "name_chinese": "关系", + "name_english": "Relationship", + "description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性", + "description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time" + }, + { + "name_chinese": "回忆", + "name_english": "Memory", + "description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性", + "description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people" + }, + { + "name_chinese": "地点", + "name_english": "Location", + "description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性", + "description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events" + }, + { + "name_chinese": "时间节点", + "name_english": "Time Point", + "description_chinese": "重要的时间标记,包含日期、事件、意义等属性", + "description_english": "Important time markers, including attributes such as date, event, and significance" + }, + { + "name_chinese": "目标", + "name_english": "Goal", + "description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性", + "description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities" + }, + { + "name_chinese": "成就", + "name_english": "Achievement", + "description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性", + "description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals" + } + ] +} + +# 导出默认场景列表 +DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE] + + +def get_scene_name(scene_config: dict, language: str = "zh") -> str: + """获取场景名称(根据语言) + + Args: + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的场景名称 + """ + if language == "en": + return scene_config.get("name_english", scene_config.get("name_chinese")) + return scene_config.get("name_chinese") + + +def get_scene_description(scene_config: dict, language: str = "zh") -> str: + """获取场景描述(根据语言) + + Args: + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的场景描述 + """ + if language == "en": + return scene_config.get("description_english", scene_config.get("description_chinese")) + return scene_config.get("description_chinese") + + +def get_type_name(type_config: dict, language: str = "zh") -> str: + """获取类型名称(根据语言) + + Args: + type_config: 类型配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的类型名称 + """ + if language == "en": + return type_config.get("name_english", type_config.get("name_chinese")) + return type_config.get("name_chinese") + + +def get_type_description(type_config: dict, language: str = "zh") -> str: + """获取类型描述(根据语言) + + Args: + type_config: 类型配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的类型描述 + """ + if language == "en": + return type_config.get("description_english", type_config.get("description_chinese")) + return type_config.get("description_chinese") diff --git a/api/app/config/default_ontology_initializer.py b/api/app/config/default_ontology_initializer.py new file mode 100644 index 00000000..3d06a352 --- /dev/null +++ b/api/app/config/default_ontology_initializer.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- +"""默认本体场景初始化器 + +本模块提供默认本体场景和类型的自动初始化功能。 +在工作空间创建时,自动添加预设的本体场景和实体类型。 + +Classes: + DefaultOntologyInitializer: 默认本体场景初始化器 +""" + +import logging +from typing import List, Optional, Tuple +from uuid import UUID + +from sqlalchemy.orm import Session + +from app.config.default_ontology_config import ( + DEFAULT_SCENES, + get_scene_name, + get_scene_description, + get_type_name, + get_type_description, +) +from app.core.logging_config import get_business_logger +from app.repositories.ontology_scene_repository import OntologySceneRepository +from app.repositories.ontology_class_repository import OntologyClassRepository + + +class DefaultOntologyInitializer: + """默认本体场景初始化器 + + 负责在工作空间创建时自动初始化默认的本体场景和类型。 + 遵循最小侵入原则,确保初始化失败不阻止工作空间创建。 + + Attributes: + db: 数据库会话 + scene_repo: 场景Repository + class_repo: 类型Repository + logger: 业务日志记录器 + """ + + def __init__(self, db: Session): + """初始化 + + Args: + db: 数据库会话 + """ + self.db = db + self.scene_repo = OntologySceneRepository(db) + self.class_repo = OntologyClassRepository(db) + self.logger = get_business_logger() + + def initialize_default_scenes( + self, + workspace_id: UUID, + language: str = "zh" + ) -> Tuple[bool, str]: + """为工作空间初始化默认场景 + + 创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。 + 如果创建失败,记录错误日志但不抛出异常。 + + Args: + workspace_id: 工作空间ID + language: 语言类型 ("zh" 或 "en"),默认为 "zh" + + Returns: + Tuple[bool, str]: (是否成功, 错误信息) + """ + try: + self.logger.info( + f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}" + ) + + scenes_created = 0 + total_types_created = 0 + + # 遍历默认场景配置 + for scene_config in DEFAULT_SCENES: + scene_name = get_scene_name(scene_config, language) + + # 创建场景及其类型 + scene_id = self._create_scene_with_types(workspace_id, scene_config, language) + + if scene_id: + scenes_created += 1 + # 统计类型数量 + types_count = len(scene_config.get("types", [])) + total_types_created += types_count + + self.logger.info( + f"场景创建成功 - scene_name={scene_name}, " + f"scene_id={scene_id}, types_count={types_count}, language={language}" + ) + else: + self.logger.warning( + f"场景创建失败 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, language={language}" + ) + + # 记录总体结果 + self.logger.info( + f"默认场景初始化完成 - workspace_id={workspace_id}, " + f"language={language}, scenes_created={scenes_created}, " + f"total_types_created={total_types_created}" + ) + + # 如果至少创建了一个场景,视为成功 + if scenes_created > 0: + return True, "" + else: + error_msg = "所有默认场景创建失败" + self.logger.error( + f"默认场景初始化失败 - workspace_id={workspace_id}, " + f"language={language}, error={error_msg}" + ) + return False, error_msg + + except Exception as e: + error_msg = f"默认场景初始化异常: {str(e)}" + self.logger.error( + f"默认场景初始化异常 - workspace_id={workspace_id}, " + f"language={language}, error={str(e)}", + exc_info=True + ) + return False, error_msg + + def _create_scene_with_types( + self, + workspace_id: UUID, + scene_config: dict, + language: str = "zh" + ) -> Optional[UUID]: + """创建场景及其类型 + + Args: + workspace_id: 工作空间ID + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + Optional[UUID]: 创建的场景ID,失败返回None + """ + try: + scene_name = get_scene_name(scene_config, language) + scene_description = get_scene_description(scene_config, language) + + # 检查是否已存在同名场景(支持向后兼容) + existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id) + if existing_scene: + self.logger.info( + f"场景已存在,跳过创建 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, " + f"language={language}" + ) + return None + + # 创建场景记录,设置 is_system_default=true + scene_data = { + "scene_name": scene_name, + "scene_description": scene_description + } + + scene = self.scene_repo.create(scene_data, workspace_id) + + # 设置系统默认标识 + scene.is_system_default = True + self.db.flush() + + self.logger.info( + f"场景创建成功 - scene_name={scene_name}, " + f"scene_id={scene.scene_id}, is_system_default=True, language={language}" + ) + + # 批量创建类型 + types_config = scene_config.get("types", []) + types_created = self._batch_create_types(scene.scene_id, types_config, language) + + self.logger.info( + f"场景类型创建完成 - scene_id={scene.scene_id}, " + f"types_created={types_created}/{len(types_config)}, language={language}" + ) + + return scene.scene_id + + except Exception as e: + scene_name = get_scene_name(scene_config, language) + self.logger.error( + f"场景创建失败 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, language={language}, error={str(e)}", + exc_info=True + ) + return None + + def _batch_create_types( + self, + scene_id: UUID, + types_config: List[dict], + language: str = "zh" + ) -> int: + """批量创建实体类型 + + Args: + scene_id: 场景ID + types_config: 类型配置列表 + language: 语言类型 ("zh" 或 "en") + + Returns: + int: 成功创建的类型数量 + """ + created_count = 0 + + for type_config in types_config: + try: + type_name = get_type_name(type_config, language) + type_description = get_type_description(type_config, language) + + # 创建类型数据 + class_data = { + "class_name": type_name, + "class_description": type_description + } + + # 创建类型 + ontology_class = self.class_repo.create(class_data, scene_id) + + # 设置系统默认标识 + ontology_class.is_system_default = True + self.db.flush() + + created_count += 1 + + self.logger.debug( + f"类型创建成功 - class_name={type_name}, " + f"class_id={ontology_class.class_id}, " + f"scene_id={scene_id}, is_system_default=True, language={language}" + ) + + except Exception as e: + type_name = get_type_name(type_config, language) + self.logger.warning( + f"单个类型创建失败,继续创建其他类型 - " + f"class_name={type_name}, scene_id={scene_id}, " + f"language={language}, error={str(e)}" + ) + # 继续创建其他类型 + continue + + return created_count diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 0e632fcc..ef65c679 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -633,11 +633,11 @@ async def get_knowledge_type_stats_api( current_user: User = Depends(get_current_user) ): """ - 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。 + 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | Memory。 会对缺失类型补 0,返回字典形式。 可选按状态过滤。 - 知识库类型根据当前用户的 current_workspace_id 过滤 - - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤 + - Memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤 - 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0 """ api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 49a2fb3a..e4a87141 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -31,7 +31,7 @@ from sqlalchemy.orm import Session from app.core.config import settings from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header -from app.core.logging_config import get_api_logger +from app.core.logging_config import get_api_logger, get_business_logger from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository api_logger = get_api_logger() +business_logger = get_business_logger() logger = logging.getLogger(__name__) router = APIRouter( @@ -399,6 +400,20 @@ async def update_scene( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认场景 + scene_repo = OntologySceneRepository(db) + scene = scene_repo.get_by_id(scene_uuid) + if scene and scene.is_system_default: + business_logger.warning( + f"尝试修改系统默认场景: user_id={current_user.id}, " + f"scene_id={scene_id}, scene_name={scene.scene_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认场景不可修改", + "该场景为系统预设场景,不允许修改" + ) + # 创建OntologyService实例 from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig @@ -491,6 +506,20 @@ async def delete_scene( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认场景 + scene_repo = OntologySceneRepository(db) + scene = scene_repo.get_by_id(scene_uuid) + if scene and scene.is_system_default: + business_logger.warning( + f"尝试删除系统默认场景: user_id={current_user.id}, " + f"scene_id={scene_id}, scene_name={scene.scene_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认场景不可删除", + "该场景为系统预设场景,不允许删除" + ) + # 创建OntologyService实例 from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig diff --git a/api/app/controllers/ontology_secondary_routes.py b/api/app/controllers/ontology_secondary_routes.py index 99017eea..607a0739 100644 --- a/api/app/controllers/ontology_secondary_routes.py +++ b/api/app/controllers/ontology_secondary_routes.py @@ -11,7 +11,7 @@ from fastapi import Depends from sqlalchemy.orm import Session from app.core.error_codes import BizCode -from app.core.logging_config import get_api_logger +from app.core.logging_config import get_api_logger, get_business_logger from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse from app.services.ontology_service import OntologyService from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig +from app.repositories.ontology_class_repository import OntologyClassRepository api_logger = get_api_logger() +business_logger = get_business_logger() def _get_dummy_ontology_service(db: Session) -> OntologyService: @@ -366,6 +368,20 @@ async def update_class_handler( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认类型 + class_repo = OntologyClassRepository(db) + ontology_class = class_repo.get_by_id(class_uuid) + if ontology_class and ontology_class.is_system_default: + business_logger.warning( + f"尝试修改系统默认类型: user_id={current_user.id}, " + f"class_id={class_id}, class_name={ontology_class.class_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认类型不可修改", + "该类型为系统预设类型,不允许修改" + ) + # 创建Service service = _get_dummy_ontology_service(db) @@ -429,6 +445,20 @@ async def delete_class_handler( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认类型 + class_repo = OntologyClassRepository(db) + ontology_class = class_repo.get_by_id(class_uuid) + if ontology_class and ontology_class.is_system_default: + business_logger.warning( + f"尝试删除系统默认类型: user_id={current_user.id}, " + f"class_id={class_id}, class_name={ontology_class.class_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认类型不可删除", + "该类型为系统预设类型,不允许删除" + ) + # 创建Service service = _get_dummy_ontology_service(db) diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index d2afb10f..9bcd8571 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -1,7 +1,7 @@ import uuid from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import APIRouter, Depends, Header, HTTPException, Query, status from sqlalchemy.orm import Session from app.core.logging_config import get_api_logger @@ -95,16 +95,29 @@ def get_workspaces( @router.post("", response_model=ApiResponse) def create_workspace( workspace: WorkspaceCreate, + language_type: str = Header(default="zh", alias="X-Language-Type"), db: Session = Depends(get_db), current_user: User = Depends(get_current_superuser), ): """创建新的工作空间""" - api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}") + from app.core.language_utils import get_language_from_header + + # 验证并获取语言参数 + language = get_language_from_header(language_type) + + api_logger.info( + f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, " + f"language={language}" + ) result = workspace_service.create_workspace( - db=db, workspace=workspace, user=current_user) + db=db, workspace=workspace, user=current_user, language=language + ) - api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}") + api_logger.info( + f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, " + f"创建者: {current_user.username}, language={language}" + ) result_schema = WorkspaceResponse.model_validate(result) return success(data=result_schema, msg="工作空间创建成功") diff --git a/api/app/models/ontology_class.py b/api/app/models/ontology_class.py index 528d934e..a8468090 100644 --- a/api/app/models/ontology_class.py +++ b/api/app/models/ontology_class.py @@ -9,7 +9,7 @@ Classes: import datetime import uuid -from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy import Column, String, DateTime, Text, ForeignKey, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -25,6 +25,9 @@ class OntologyClass(Base): # 类型信息 class_name = Column(String(200), nullable=False, comment="类型名称") class_description = Column(Text, nullable=True, comment="类型描述") + + # 系统默认标识 + is_system_default = Column(Boolean, default=False, nullable=False, comment="是否为系统默认类型") # 外键:关联到本体场景 scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID") diff --git a/api/app/models/ontology_scene.py b/api/app/models/ontology_scene.py index 350bfdd6..3ce42cad 100644 --- a/api/app/models/ontology_scene.py +++ b/api/app/models/ontology_scene.py @@ -9,7 +9,7 @@ Classes: import datetime import uuid -from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint +from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -28,6 +28,9 @@ class OntologyScene(Base): # 场景信息 scene_name = Column(String(200), nullable=False, comment="场景名称") scene_description = Column(Text, nullable=True, comment="场景描述") + + # 系统默认标识 + is_system_default = Column(Boolean, default=False, nullable=False, index=True, comment="是否为系统默认场景") # 外键:关联到工作空间 workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID") diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index da8a8e06..ad295667 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -816,11 +816,11 @@ class MemoryAgentService: """ 统计知识库类型分布,包含: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) - 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤) + 2. Neo4j 中的 Memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤) 3. total: 所有类型的总和 参数: - - end_user_id: 用户组ID(可选,未提供时 memory 统计为 0) + - end_user_id: 用户组ID(可选,未提供时 Memory 统计为 0) - only_active: 是否仅统计有效记录 - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - db: 数据库会话 @@ -831,7 +831,7 @@ class MemoryAgentService: "Web": count, "Third-party": count, "Folder": count, - "memory": chunk_count, + "Memory": chunk_count, "total": sum_of_all } """ @@ -912,17 +912,17 @@ class MemoryAgentService: total_chunks += chunk_count logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - result["memory"] = total_chunks + result["Memory"] = total_chunks logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") else: # 没有 workspace_id 时,返回 0 - result["memory"] = 0 + result["Memory"] = 0 logger.info("未提供 workspace_id,memory 统计为 0") except Exception as e: logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) # 如果 Neo4j 查询失败,memory 设为 0 - result["memory"] = 0 + result["Memory"] = 0 # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 6f102695..2f8cdc70 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -30,6 +30,7 @@ from app.schemas.workspace_schema import ( WorkspaceModelsUpdate, WorkspaceUpdate, ) +from app.config.default_ontology_initializer import DefaultOntologyInitializer # 获取业务逻辑专用日志器 business_logger = get_business_logger() @@ -129,7 +130,7 @@ def _create_workspace_only( raise def create_workspace( - db: Session, workspace: WorkspaceCreate, user: User + db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh" ) -> Workspace: business_logger.info( f"创建工作空间: {workspace.name}, 创建者: {user.username}, " @@ -145,10 +146,68 @@ def create_workspace( db=db, workspace=workspace, tenant_id=user.tenant_id ) business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}") - db.commit() + db.flush() # 使用 flush 而不是 commit,获取 ID 但不提交事务 db.refresh(db_workspace) + # Initialize default ontology scenes for the workspace (先创建本体场景) + default_scene_id = None + try: + initializer = DefaultOntologyInitializer(db) + success, error_msg = initializer.initialize_default_scenes( + db_workspace.id, language=language + ) + + if success: + business_logger.info( + f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})" + ) + + # 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景 + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.config.default_ontology_config import ( + ONLINE_EDUCATION_SCENE, + EMOTIONAL_COMPANION_SCENE, + get_scene_name + ) + + scene_repo = OntologySceneRepository(db) + + # 优先尝试获取教育场景 + education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language) + education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id) + + if education_scene: + default_scene_id = education_scene.scene_id + business_logger.info( + f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})" + ) + else: + # 如果教育场景不存在,尝试获取情感陪伴场景 + companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language) + companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id) + + if companion_scene: + default_scene_id = companion_scene.scene_id + business_logger.info( + f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})" + ) + else: + business_logger.warning( + f"未找到任何默认场景 (education={education_scene_name}, companion={companion_scene_name})" + ) + else: + business_logger.warning( + f"为工作空间 {db_workspace.id} 创建默认本体场景失败: {error_msg} (language={language})" + ) + except Exception as ontology_error: + business_logger.error( + f"为工作空间 {db_workspace.id} 创建默认本体场景异常: {str(ontology_error)} (language={language})" + ) + # Don't fail workspace creation if default ontology initialization fails + # The workspace can still function without default ontology scenes + # Create default memory config for the workspace (only for neo4j storage types) + # 将默认场景ID(教育场景或情感陪伴场景)关联到记忆配置 if workspace.storage_type == 'neo4j': try: _create_default_memory_config( @@ -158,9 +217,10 @@ def create_workspace( llm_id=llm, embedding_id=embedding, rerank_id=rerank, + scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景) ) business_logger.info( - f"为工作空间 {db_workspace.id} 创建默认记忆配置成功" + f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})" ) except Exception as mc_error: business_logger.error( @@ -209,7 +269,6 @@ def create_workspace( db=db, knowledge=knowledge_data ) - db.commit() business_logger.info( f"为工作空间 {db_workspace.id} 自动创建知识库成功: " f"{db_knowledge.name} (ID: {db_knowledge.id})" @@ -224,6 +283,12 @@ def create_workspace( BizCode.INTERNAL_ERROR ) + # 统一提交所有更改 + db.commit() + business_logger.info( + f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交" + ) + return db_workspace except Exception as e: @@ -919,6 +984,43 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: f"Workspace {workspace.id} missing default memory config, creating one" ) + # 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景 + default_scene_id = None + try: + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.config.default_ontology_config import ( + ONLINE_EDUCATION_SCENE, + EMOTIONAL_COMPANION_SCENE, + get_scene_name + ) + + scene_repo = OntologySceneRepository(db) + # 尝试中文和英文场景名称 + for language in ["zh", "en"]: + # 优先尝试教育场景 + education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language) + education_scene = scene_repo.get_by_name(education_scene_name, workspace.id) + if education_scene: + default_scene_id = education_scene.scene_id + business_logger.info( + f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}" + ) + break + + # 如果教育场景不存在,尝试情感陪伴场景 + companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language) + companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id) + if companion_scene: + default_scene_id = companion_scene.scene_id + business_logger.info( + f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}" + ) + break + except Exception as scene_error: + business_logger.warning( + f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}" + ) + try: _create_default_memory_config( db=db, @@ -927,6 +1029,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: llm_id=uuid.UUID(workspace.llm) if workspace.llm else None, embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None, rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None, + scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景) ) except Exception as e: business_logger.error( @@ -1008,6 +1111,7 @@ def _create_default_memory_config( llm_id: Optional[uuid.UUID] = None, embedding_id: Optional[uuid.UUID] = None, rerank_id: Optional[uuid.UUID] = None, + scene_id: Optional[uuid.UUID] = None, ) -> None: """Create a default memory config for a newly created workspace. @@ -1018,6 +1122,7 @@ def _create_default_memory_config( llm_id: Optional LLM model ID embedding_id: Optional embedding model ID rerank_id: Optional rerank model ID + scene_id: Optional ontology scene ID (默认关联教育场景) """ from app.models.memory_config_model import MemoryConfig @@ -1031,12 +1136,13 @@ def _create_default_memory_config( llm_id=str(llm_id) if llm_id else None, embedding_id=str(embedding_id) if embedding_id else None, rerank_id=str(rerank_id) if rerank_id else None, + scene_id=scene_id, # 关联本体场景ID state=True, # Active by default is_default=True, # Mark as workspace default ) db.add(default_config) - db.commit() + db.flush() # 使用 flush 而不是 commit,让调用者统一提交 business_logger.info( "Created default memory config for workspace", @@ -1044,5 +1150,6 @@ def _create_default_memory_config( "workspace_id": str(workspace_id), "config_id": str(config_id), "config_name": default_config.config_name, + "scene_id": str(scene_id) if scene_id else None, } ) diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index 4b0257bb..8494e824 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit 4b0257bb4e7dc384b2aaf849b0bd6eae4b39835d +Subproject commit 8494e82498cb99c70ac67a64a544ff872432363a From 77ea0680fb2c39285a0fa1ce0862b045ef6b91a1 Mon Sep 17 00:00:00 2001 From: Mark Date: Sat, 28 Feb 2026 19:22:13 +0800 Subject: [PATCH 019/164] [add] migration script --- .../versions/4bf27c66ae63_202602281918.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 api/migrations/versions/4bf27c66ae63_202602281918.py diff --git a/api/migrations/versions/4bf27c66ae63_202602281918.py b/api/migrations/versions/4bf27c66ae63_202602281918.py new file mode 100644 index 00000000..78b13435 --- /dev/null +++ b/api/migrations/versions/4bf27c66ae63_202602281918.py @@ -0,0 +1,44 @@ +"""202602281918 + +Revision ID: 4bf27c66ae63 +Revises: 7672d8f0f939 +Create Date: 2026-02-28 19:18:38.332468 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4bf27c66ae63' +down_revision: Union[str, None] = '7672d8f0f939' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Add columns as nullable first + op.add_column('ontology_class', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认类型')) + op.add_column('ontology_scene', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认场景')) + + # Set default value for existing rows + op.execute("UPDATE ontology_class SET is_system_default = false WHERE is_system_default IS NULL") + op.execute("UPDATE ontology_scene SET is_system_default = false WHERE is_system_default IS NULL") + + # Now make columns NOT NULL + op.alter_column('ontology_class', 'is_system_default', nullable=False) + op.alter_column('ontology_scene', 'is_system_default', nullable=False) + + op.create_index(op.f('ix_ontology_scene_is_system_default'), 'ontology_scene', ['is_system_default'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_ontology_scene_is_system_default'), table_name='ontology_scene') + op.drop_column('ontology_scene', 'is_system_default') + op.drop_column('ontology_class', 'is_system_default') + # ### end Alembic commands ### From 8b546b73669c455e7ded0ccc0ecd4cbf2a2a181f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 19:26:16 +0800 Subject: [PATCH 020/164] [add]Complete the interface integration for the display of semantic pruning for streaming output. --- api/app/services/pilot_run_service.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 4cfa158d..c39d089e 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -106,6 +106,7 @@ async def run_pilot_extraction( # ========== 步骤 2.1: 语义剪枝 ========== pruned_dialogs = [dialog] deleted_messages = [] # 记录被删除的消息 + pruning_stats = None # 保存剪枝统计信息,用于最终汇总 if memory_config.pruning_enabled: try: @@ -147,13 +148,17 @@ async def run_pilot_extraction( if msg["content"] not in remaining_contents ] - pruning_result = { + # 保存剪枝统计信息(用于最终汇总,只保留deleted_count) + pruning_stats = { "enabled": True, "scene": config.pruning_scene, "threshold": config.pruning_threshold, - "original_count": original_msg_count, - "remaining_count": remaining_msg_count, "deleted_count": deleted_msg_count, + } + + # 输出剪枝结果(显示删除的消息详情) + pruning_result = { + "type": "pruning", "deleted_messages": deleted_messages, } @@ -163,7 +168,7 @@ async def run_pilot_extraction( ) if progress_callback: - await progress_callback("text_preprocessing_pruning", "语义剪枝完成", pruning_result) + await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result) else: logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") pruned_dialogs = [dialog] @@ -173,19 +178,16 @@ async def run_pilot_extraction( pruned_dialogs = [dialog] if progress_callback: error_result = { - "enabled": True, + "type": "pruning", "error": str(e), "fallback": "使用原始对话" } - await progress_callback("text_preprocessing_pruning", "语义剪枝失败", error_result) + await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result) else: logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") - if progress_callback: - pruning_result = { - "enabled": False, - "message": "语义剪枝已关闭" - } - await progress_callback("text_preprocessing_pruning", "语义剪枝已关闭", pruning_result) + pruning_stats = { + "enabled": False, + } # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( @@ -203,6 +205,7 @@ async def run_pilot_extraction( if hasattr(dlg, 'chunks') and dlg.chunks: for i, chunk in enumerate(dlg.chunks): chunk_result = { + "type": "chunking", "chunk_index": i + 1, "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, "full_length": len(chunk.content), @@ -211,11 +214,17 @@ async def run_pilot_extraction( } await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + # 构建预处理完成总结(包含剪枝统计) preprocessing_summary = { "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } + + # 添加剪枝统计信息 + if pruning_stats: + preprocessing_summary["pruning"] = pruning_stats + await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) From 87df352adc4306411b41315dcba698e8ae525be7 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 11:42:46 +0800 Subject: [PATCH 021/164] feat(web): memoryExtractionEngine add pruning --- web/src/i18n/en.ts | 4 + web/src/i18n/zh.ts | 4 + .../components/Result.tsx | 96 +++++++++++-------- 3 files changed, 62 insertions(+), 42 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index fdbde290..02add0ec 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1640,6 +1640,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re scene_type_distribution: 'Scene Type Distribution', general_type_distribution: 'General Type Distribution', unmatched: 'Unmatched', + disagreementCase: 'Disagreement Case', + Pruned: 'Pruned', + pruning: 'Pruning', + pruning_desc: 'Text pruning {{count}} fragments' }, memoryConversation: { searchPlaceholder: 'Enter user ID...', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index c855667f..06abf63a 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1637,6 +1637,10 @@ export const zh = { scene_type_distribution: '场景类型', general_type_distribution: '通用类型', unmatched: '未匹配', + disagreementCase: '不一致案例', + Pruned: '已剪枝', + pruning: '剪枝', + pruning_desc: '文本剪枝{{count}}个片段' }, memoryConversation: { chatEmpty:'有什么我可以帮您的吗?', diff --git a/web/src/views/MemoryExtractionEngine/components/Result.tsx b/web/src/views/MemoryExtractionEngine/components/Result.tsx index 68ff397b..6504f571 100644 --- a/web/src/views/MemoryExtractionEngine/components/Result.tsx +++ b/web/src/views/MemoryExtractionEngine/components/Result.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:30:11 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 21:04:14 + * @Last Modified time: 2026-03-02 11:41:12 */ /** * Result Component @@ -91,7 +91,7 @@ const Result: FC = ({ loading, handleSave }) => { setDeduplication({...initObj} as ModuleItem) setTestResult({} as TestResult) const handleStreamMessage = (list: SSEMessage[]) => { - + list.forEach((data: AnyObject) => { switch(data.event) { case 'text_preprocessing': // Start text preprocessing @@ -104,7 +104,7 @@ const Result: FC = ({ loading, handleSave }) => { case 'text_preprocessing_result': // Text preprocessing in progress setTextPreprocessing(prev => ({ ...prev, - data: [...prev.data, data.data?.data] + data: [...prev.data, data.data?.deleted_messages ? { deleted_messages: data.data?.deleted_messages } : data.data?.data], })) break case 'text_preprocessing_complete': // Text preprocessing complete @@ -193,9 +193,9 @@ const Result: FC = ({ loading, handleSave }) => { dialogue_text: t('memoryExtractionEngine.exampleText'), custom_text: runForm.getFieldValue('custom_text') }, handleStreamMessage) - .finally(() => { - setRunLoading(false) - }) + .finally(() => { + setRunLoading(false) + }) } const completedNum = [textPreprocessing, knowledgeExtraction, creatingNodesEdges, deduplication].filter(item => item.status === 'completed').length const deduplicationData = groupDataByType(deduplication.data, 'result_type') @@ -251,10 +251,10 @@ const Result: FC = ({ loading, handleSave }) => {
: !testResult || Object.keys(testResult).length === 0 - ? } className="rb:mb-3.5"> - {t('memoryExtractionEngine.warning')} - - : } className="rb:mb-3.5"> + ? } className="rb:mb-3.5"> + {t('memoryExtractionEngine.warning')} + + : } className="rb:mb-3.5"> {t('memoryExtractionEngine.success')} } @@ -266,15 +266,28 @@ const Result: FC = ({ loading, handleSave }) => { headerType="borderL" headerClassName="rb:before:bg-[#155EEF]!" > - {textPreprocessing.data.map((vo, index) => ( -
- -
- ))} + {textPreprocessing.data.map((vo, index) => { + if (vo.deleted_messages) { + return
+
{t('memoryExtractionEngine.Pruned')}
+ {vo.deleted_messages.map((msg: any, idx: number) => ( +
+ +
+ ))} +
+ } + return ( +
+ +
+ ) + })} {formatTime(textPreprocessing)} {textPreprocessing.result && } className="rb:mt-3"> - {t('memoryExtractionEngine.text_preprocessing_desc', { count: textPreprocessing.result.total_chunks })}, + {t('memoryExtractionEngine.pruning_desc', { count: textPreprocessing.result.pruning.deleted_count || 0 })}, + {t('memoryExtractionEngine.text_preprocessing_desc', { count: textPreprocessing.result.total_chunks })}, {t('memoryExtractionEngine.chunkerStrategy')}: {t(`memoryExtractionEngine.${lowercaseFirst(textPreprocessing.result.chunker_strategy)}`)} } @@ -286,7 +299,7 @@ const Result: FC = ({ loading, handleSave }) => { headerType="borderL" headerClassName="rb:before:bg-[#155EEF]!" > - {knowledgeExtraction.data.map((vo, index) => + {knowledgeExtraction.data.map((vo, index) =>
{vo.statement}
)} {formatTime(knowledgeExtraction)} @@ -345,31 +358,30 @@ const Result: FC = ({ loading, handleSave }) => { {Object.keys(resultObj).map((key, index) => { const keys = (resultObj as Record)[key].split('.') return ( -
-
{(testResult?.[keys[0] as keyof TestResult] as any)?.[keys[1]]}
-
{t(`memoryExtractionEngine.${key}`)}
-
- {} - {key === 'extractTheNumberOfEntities' && testResult.dedup - ? t(`memoryExtractionEngine.${key}Desc`, { - num: testResult.dedup.total_merged_count, - exact: testResult.dedup.breakdown.exact, - fuzzy: testResult.dedup.breakdown.fuzzy, - llm: testResult.dedup.breakdown.llm, - }) - : key === 'numberOfEntityDisambiguation' && testResult.disambiguation - ? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.disambiguation.effects?.length, block_count: testResult.disambiguation.block_count }) - : key === 'numberOfRelationalTriples' && testResult.triplets - ? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.triplets.count }) - :t(`memoryExtractionEngine.${key}Desc`) - } +
+
{(testResult?.[keys[0] as keyof TestResult] as any)?.[keys[1]]}
+
{t(`memoryExtractionEngine.${key}`)}
+
+ {key === 'extractTheNumberOfEntities' && testResult.dedup + ? t(`memoryExtractionEngine.${key}Desc`, { + num: testResult.dedup.total_merged_count, + exact: testResult.dedup.breakdown.exact, + fuzzy: testResult.dedup.breakdown.fuzzy, + llm: testResult.dedup.breakdown.llm, + }) + : key === 'numberOfEntityDisambiguation' && testResult.disambiguation + ? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.disambiguation.effects?.length, block_count: testResult.disambiguation.block_count }) + : key === 'numberOfRelationalTriples' && testResult.triplets + ? t(`memoryExtractionEngine.${key}Desc`, { num: testResult.triplets.count }) + :t(`memoryExtractionEngine.${key}Desc`) + } +
-
- )})} + )})}
} - + {testResult?.dedup?.impact && testResult.dedup.impact?.length > 0 && = ({ loading, handleSave }) => {
} - + {testResult?.disambiguation && testResult.disambiguation?.effects?.length > 0 && = ({ loading, handleSave }) => {
0, })}> -
Disagreement Case {index +1}:
+
{t('memoryExtractionEngine.disagreementCase')} {index +1}:
-{item.left.name}({item.left.type}) vs {item.right.name}({item.right.type}) → {item.result}
))} @@ -409,7 +421,7 @@ const Result: FC = ({ loading, handleSave }) => {
} - + {testResult?.core_entities && testResult?.core_entities.length > 0 && = ({ loading, handleSave }) => {
} - + {testResult?.triplet_samples && testResult?.triplet_samples.length > 0 && Date: Fri, 27 Feb 2026 11:06:00 +0800 Subject: [PATCH 022/164] [fix]Complete the API call logic for the homepage --- .../controllers/memory_agent_controller.py | 5 +- .../memory_dashboard_controller.py | 50 +++++++++++++---- api/app/services/memory_agent_service.py | 53 ++----------------- 3 files changed, 46 insertions(+), 62 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index ef65c679..b88e65ff 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -633,12 +633,11 @@ async def get_knowledge_type_stats_api( current_user: User = Depends(get_current_user) ): """ - 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | Memory。 + 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 会对缺失类型补 0,返回字典形式。 可选按状态过滤。 - 知识库类型根据当前用户的 current_workspace_id 过滤 - - Memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤 - - 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0 + - 如果用户没有当前工作空间,对应的统计返回 0 """ api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") try: diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 88684a39..475d184e 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse from app.services import memory_dashboard_service, memory_storage_service, workspace_service from app.services.memory_agent_service import get_end_users_connected_configs_batch +from app.services.app_statistics_service import AppStatisticsService from app.core.logging_config import get_api_logger # 获取API专用日志器 @@ -469,6 +470,8 @@ async def get_chunk_insight( @router.get("/dashboard_data", response_model=ApiResponse) async def dashboard_data( end_user_id: Optional[str] = Query(None, description="可选的用户ID"), + start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"), + end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): @@ -503,6 +506,15 @@ async def dashboard_data( workspace_id = current_user.current_workspace_id api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据") + # 如果没有提供时间范围,默认使用最近30天 + if start_date is None or end_date is None: + from datetime import datetime, timedelta + end_dt = datetime.now() + start_dt = end_dt - timedelta(days=30) + end_date = int(end_dt.timestamp() * 1000) + start_date = int(start_dt.timestamp() * 1000) + api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}") + # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, @@ -563,17 +575,22 @@ async def dashboard_data( except Exception as e: api_logger.warning(f"获取知识库类型统计失败: {str(e)}") - # 3. 获取API调用增量(total_api_call,转换为整数) + # 3. 获取API调用统计(total_api_call) try: - api_increment = memory_dashboard_service.get_workspace_api_increment( - db=db, + # 使用 AppStatisticsService 获取真实的API调用统计 + app_stats_service = AppStatisticsService(db) + api_stats = app_stats_service.get_workspace_api_statistics( workspace_id=workspace_id, - current_user=current_user + start_date=start_date, + end_date=end_date ) - neo4j_data["total_api_call"] = api_increment - api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}") + # 计算总调用次数 + total_api_calls = sum(item.get("total_calls", 0) for item in api_stats) + neo4j_data["total_api_call"] = total_api_calls + api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}") except Exception as e: - api_logger.warning(f"获取API调用增量失败: {str(e)}") + api_logger.error(f"获取API调用统计失败: {str(e)}") + neo4j_data["total_api_call"] = 0 result["neo4j_data"] = neo4j_data api_logger.info("成功获取neo4j_data") @@ -602,10 +619,23 @@ async def dashboard_data( total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) rag_data["total_knowledge"] = total_kb - # total_api_call: 固定值 - rag_data["total_api_call"] = 1024 + # total_api_call: 使用 AppStatisticsService 获取真实的API调用统计 + try: + app_stats_service = AppStatisticsService(db) + api_stats = app_stats_service.get_workspace_api_statistics( + workspace_id=workspace_id, + start_date=start_date, + end_date=end_date + ) + # 计算总调用次数 + total_api_calls = sum(item.get("total_calls", 0) for item in api_stats) + rag_data["total_api_call"] = total_api_calls + api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}") + except Exception as e: + api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}") + rag_data["total_api_call"] = 0 - api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}") + api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") except Exception as e: api_logger.warning(f"获取RAG相关数据失败: {str(e)}") diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index ad295667..1f3667a6 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -816,11 +816,10 @@ class MemoryAgentService: """ 统计知识库类型分布,包含: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) - 2. Neo4j 中的 Memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤) - 3. total: 所有类型的总和 + 2. total: 所有类型的总和 参数: - - end_user_id: 用户组ID(可选,未提供时 Memory 统计为 0) + - end_user_id: 用户组ID(可选,保留参数以保持接口兼容性) - only_active: 是否仅统计有效记录 - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - db: 数据库会话 @@ -831,7 +830,6 @@ class MemoryAgentService: "Web": count, "Third-party": count, "Folder": count, - "Memory": chunk_count, "total": sum_of_all } """ @@ -878,51 +876,8 @@ class MemoryAgentService: logger.error(f"知识库类型统计失败: {e}") raise Exception(f"知识库类型统计失败: {e}") - # 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数) - try: - if current_workspace_id: - # 获取当前空间下的所有宿主 - from app.repositories import app_repository, end_user_repository - from app.schemas.app_schema import App as AppSchema - from app.schemas.end_user_schema import EndUser as EndUserSchema - - # 查询应用并转换为 Pydantic 模型 - apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) - apps = [AppSchema.model_validate(h) for h in apps_orm] - app_ids = [app.id for app in apps] - - # 获取所有宿主 - end_users = [] - for app_id in app_ids: - end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) - end_users.extend(h for h in end_user_orm_list) - - # 统计所有宿主的 Chunk 总数 - total_chunks = 0 - for end_user in end_users: - end_user_id_str = str(end_user.id) - memory_query = """ - MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count - """ - neo4j_result = await _neo4j_connector.execute_query( - memory_query, - end_user_id=end_user_id_str, - ) - chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 - total_chunks += chunk_count - logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - - result["Memory"] = total_chunks - logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") - else: - # 没有 workspace_id 时,返回 0 - result["Memory"] = 0 - logger.info("未提供 workspace_id,memory 统计为 0") - - except Exception as e: - logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) - # 如果 Neo4j 查询失败,memory 设为 0 - result["Memory"] = 0 + # 2. 统计 Neo4j 中的 memory 总量已移除 + # memory 字段不再返回 # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( From 3a36d038eee85425ba75dafc7073ff643cd67827 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 12:20:51 +0800 Subject: [PATCH 023/164] [fix]Reconstructing memory incremental statistical scheduling task --- api/app/celery_app.py | 17 ++-- api/app/core/config.py | 1 - api/app/tasks.py | 197 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 203 insertions(+), 12 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 8ef44975..f422f4a0 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -82,7 +82,7 @@ celery_app.conf.update( 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, - 'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'}, + 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, }, ) @@ -115,16 +115,11 @@ beat_schedule_config = { "config_id": None, # 使用默认配置,可以通过环境变量配置 }, }, + "write-all-workspaces-memory": { + "task": "app.tasks.write_all_workspaces_memory_task", + "schedule": memory_increment_schedule, + "args": (), + }, } -#如果配置了默认工作空间ID,则添加记忆总量统计任务 -if settings.DEFAULT_WORKSPACE_ID: - beat_schedule_config["write-total-memory"] = { - "task": "app.controllers.memory_storage_controller.search_all", - "schedule": memory_increment_schedule, - "kwargs": { - "workspace_id": settings.DEFAULT_WORKSPACE_ID, - }, - } - celery_app.conf.beat_schedule = beat_schedule_config diff --git a/api/app/core/config.py b/api/app/core/config.py index 0962b545..19998d32 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -201,7 +201,6 @@ class Settings: REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) - DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) # Memory Cache Regeneration Configuration diff --git a/api/app/tasks.py b/api/app/tasks.py index d408a0da..8e3aea85 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1304,6 +1304,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "workspace_id": workspace_id, "elapsed_time": elapsed_time, } +@celery_app.task( + name="app.tasks.write_all_workspaces_memory_task", + bind=True, + ignore_result=False, + max_retries=3, + acks_late=True, + time_limit=3600, + soft_time_limit=3300, +) +def write_all_workspaces_memory_task(self) -> Dict[str, Any]: + """定时任务:遍历所有工作空间,统计并写入记忆增量 + + 此任务会: + 1. 查询所有活跃的工作空间 + 2. 对每个工作空间统计记忆总量 + 3. 将统计结果写入 memory_increments 表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_api_logger + from app.models.workspace_model import Workspace + from app.models.app_model import App + from app.models.end_user_model import EndUser + from app.repositories.memory_increment_repository import write_memory_increment + from app.services.memory_storage_service import search_all + + api_logger = get_api_logger() + + with get_db_context() as db: + try: + # 获取所有活跃的工作空间 + workspaces = db.query(Workspace).filter( + Workspace.is_active.is_(True) + ).all() + + if not workspaces: + api_logger.warning("没有找到活跃的工作空间") + return { + "status": "SUCCESS", + "message": "没有找到活跃的工作空间", + "workspace_count": 0, + "workspace_results": [] + } + + api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量") + all_workspace_results = [] + + # 遍历每个工作空间 + for workspace in workspaces: + workspace_id = workspace.id + api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})") + + try: + # 1. 查询当前workspace下的所有app(仅未删除的) + apps = db.query(App).filter( + App.workspace_id == workspace_id, + App.is_active.is_(True) + ).all() + + if not apps: + # 如果没有app,总量为0 + memory_increment = write_memory_increment( + db=db, + workspace_id=workspace_id, + total_num=0 + ) + all_workspace_results.append({ + "workspace_id": str(workspace_id), + "workspace_name": workspace.name, + "status": "SUCCESS", + "total_num": 0, + "end_user_count": 0, + "memory_increment_id": str(memory_increment.id), + "created_at": memory_increment.created_at.isoformat(), + }) + api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0") + continue + + # 2. 查询所有app下的end_user_id(去重) + app_ids = [app.id for app in apps] + end_users = db.query(EndUser.id).filter( + EndUser.app_id.in_(app_ids) + ).distinct().all() + + # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 + total_num = 0 + end_user_details = [] + + for (end_user_id,) in end_users: + try: + # 调用 search_all 接口查询该宿主的总量 + result = await search_all(str(end_user_id)) + user_total = result.get("total", 0) + total_num += user_total + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": user_total + }) + except Exception as e: + # 记录单个用户查询失败,但继续处理其他用户 + api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": 0, + "error": str(e) + }) + + # 4. 写入数据库 + memory_increment = write_memory_increment( + db=db, + workspace_id=workspace_id, + total_num=total_num + ) + + all_workspace_results.append({ + "workspace_id": str(workspace_id), + "workspace_name": workspace.name, + "status": "SUCCESS", + "total_num": total_num, + "end_user_count": len(end_users), + "memory_increment_id": str(memory_increment.id), + "created_at": memory_increment.created_at.isoformat(), + }) + + api_logger.info( + f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}" + ) + + except Exception as e: + db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间 + api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}") + all_workspace_results.append({ + "workspace_id": str(workspace_id), + "workspace_name": workspace.name, + "status": "FAILURE", + "error": str(e), + "total_num": 0, + "end_user_count": 0, + }) + + total_memory = sum(r.get("total_num", 0) for r in all_workspace_results) + success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS") + + return { + "status": "SUCCESS", + "message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}", + "workspace_count": len(workspaces), + "success_count": success_count, + "total_memory": total_memory, + "workspace_results": all_workspace_results + } + + except Exception as e: + api_logger.error(f"记忆增量统计任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "workspace_count": 0, + "workspace_results": [] + } + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id + } @celery_app.task( From ed0d963aeca622a44577fc0c0e3bb1602db456d5 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 14:47:23 +0800 Subject: [PATCH 024/164] [fix]Modify the person who generates the user summary --- .../core/memory/utils/prompt/prompt_utils.py | 14 +++++++--- .../utils/prompt/prompts/user_summary.jinja2 | 26 ++++++++++++++----- api/app/services/user_memory_service.py | 24 ++++++++++++++++- 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 50d31f2a..d88f50cf 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -400,7 +400,8 @@ async def render_user_summary_prompt( user_id: str, entities: str, statements: str, - language: str = "zh" + language: str = "zh", + user_display_name: str = None ) -> str: """ Renders the user summary prompt using the user_summary.jinja2 template. @@ -410,16 +411,22 @@ async def render_user_summary_prompt( entities: Core entities with frequency information statements: Representative statement samples language: The language to use for summary generation ("zh" for Chinese, "en" for English) + user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user") Returns: Rendered prompt content as string """ + # 如果没有提供 user_display_name,使用默认值 + if user_display_name is None: + user_display_name = "该用户" if language == "zh" else "the user" + template = prompt_env.get_template("user_summary.jinja2") rendered_prompt = template.render( user_id=user_id, entities=entities, statements=statements, - language=language + language=language, + user_display_name=user_display_name ) # 记录渲染结果到提示日志 @@ -429,7 +436,8 @@ async def render_user_summary_prompt( 'user_id': user_id, 'entities_len': len(entities), 'statements_len': len(statements), - 'language': language + 'language': language, + 'user_display_name': user_display_name }) return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 index 35619112..30b48719 100644 --- a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 @@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti {% endif %} ===Inputs=== -{% if user_id %} -- User ID: {{ user_id }} +{% if user_display_name %} +- User Display Name: {{ user_display_name }} {% endif %} {% if entities %} - Core Entities & Frequency: {{ entities }} @@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti 3. Avoid excessive adjectives and empty phrases 4. Strictly follow the output format specified below +{% if language == "zh" %} +**【严格人称规定】** +- 在描述用户时,必须使用"{{ user_display_name }}"作为人称 +- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户 +- 绝对禁止在摘要中出现任何形式的UUID或ID字符串 +- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA) +{% else %} +**【STRICT PRONOUN RULES】** +- When describing the user, you MUST use "{{ user_display_name }}" as the reference +- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user +- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary +- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they) +{% endif %} + **Section-Specific Requirements:** {% if language == "zh" %} @@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti {% if language == "zh" %} Example Input: -- User ID: user_12345 +- User Display Name: 张三 - Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7) - Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则 Example Output: 【基本介绍】 -我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 +张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 【性格特点】 性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。 @@ -121,13 +135,13 @@ Example Output: "让每一个产品决策都充满温度。" {% else %} Example Input: -- User ID: user_12345 +- User Display Name: John - Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7) - Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle Example Output: 【Basic Introduction】 -This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities. +John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities. 【Personality Traits】 Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution. diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 80413c12..e34756b9 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1163,11 +1163,32 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st """ from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt from app.core.language_utils import validate_language + from app.repositories.end_user_repository import EndUserRepository + from app.db import get_db import re # 验证语言参数 language = validate_language(language) + # 获取用户的 other_name 字段 + user_display_name = "该用户" if language == "zh" else "the user" + if end_user_id: + try: + # 获取数据库会话并查询用户信息 + db = next(get_db()) + try: + repo = EndUserRepository(db) + end_user = repo.get_by_id(uuid.UUID(end_user_id)) + if end_user and end_user.other_name: + user_display_name = end_user.other_name + logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") + else: + logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") + finally: + db.close() + except Exception as e: + logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") + # 创建 UserSummaryHelper 实例 user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123")) @@ -1184,7 +1205,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st user_id=user_summary_tool.user_id, entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)", statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)", - language=language + language=language, + user_display_name=user_display_name ) messages = [ From 6db6c33564c9cd85573662294239e83e9d43fde5 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 14:59:28 +0800 Subject: [PATCH 025/164] [fix]Reduce the default number of items returned for popular tags --- api/app/core/memory/analytics/hot_memory_tags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index f99b811e..5ffc6fed 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -139,10 +139,10 @@ async def get_raw_tags_from_db( return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 - 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 + 查询更多的标签(limit=10)给LLM提供更丰富的上下文进行筛选。 Args: end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id From 4d59e04abad24d242269cf46f966fa1dfa8c25b4 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 15:08:06 +0800 Subject: [PATCH 026/164] [changes]Ensure that there are sufficient labels for LLM to process, and control the number of label returns. --- api/app/core/memory/analytics/hot_memory_tags.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 5ffc6fed..abb0f138 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -142,11 +142,11 @@ async def get_raw_tags_from_db( async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 - 查询更多的标签(limit=10)给LLM提供更丰富的上下文进行筛选。 + 查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。 Args: end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id - limit: 返回的标签数量限制 + limit: 最终返回的标签数量限制(默认10) by_user: 是否按user_id查询(默认False,按end_user_id查询) Raises: @@ -161,8 +161,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = # 使用项目的Neo4jConnector connector = Neo4jConnector() try: - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user) + # 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文) + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) if not raw_tags_with_freq: return [] @@ -177,7 +178,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = if tag in meaningful_tag_names: final_tags.append((tag, freq)) - return final_tags + # 4. 限制返回的标签数量 + return final_tags[:limit] finally: # 确保关闭连接 await connector.close() From 0ba370052ecddf7a3d697bc5a4acf825c5dd0c4e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 16:09:22 +0800 Subject: [PATCH 027/164] [fix]Address the shortcomings of intelligent pruning --- .../data_preprocessing/data_pruning.py | 518 ++++++++++++++---- 1 file changed, 423 insertions(+), 95 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index d19e511b..2d0142c6 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -5,14 +5,17 @@ - 对话级一次性抽取判定相关性 - 仅对"不相关对话"的消息按比例删除 - 重要信息(时间、编号、金额、联系方式、地址等)优先保留 +- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化 """ +import asyncio import os import hashlib import json import re +from collections import OrderedDict from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict, Tuple, Set from pydantic import BaseModel, Field from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext @@ -36,6 +39,23 @@ class DialogExtractionResponse(BaseModel): keywords: List[str] = Field(default_factory=list) +class MessageImportanceResponse(BaseModel): + """消息重要性批量判断的结构化返回(用于LLM语义判断)。 + + - importance_scores: 消息索引到重要性分数的映射 (0-10分) + - reasons: 可选的判断理由 + """ + importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射") + reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由") + + +class QAPair(BaseModel): + """问答对模型,用于识别和保护对话中的问答结构。""" + question_idx: int = Field(..., description="问题消息的索引") + answer_idx: int = Field(..., description="答案消息的索引") + confidence: float = Field(default=1.0, description="问答对的置信度(0-1)") + + class SemanticPruner: """语义剪枝:在预处理与分块之间过滤与场景不相关内容。 @@ -43,109 +63,353 @@ class SemanticPruner: 重要信息(时间、编号、金额、联系方式、地址等)优先保留。 """ - def __init__(self, config: Optional[PruningConfig] = None, llm_client=None): - cfg_dict = get_pruning_config() if config is None else config.model_dump() - self.config = PruningConfig.model_validate(cfg_dict) + def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5): + # 如果没有提供config,使用默认配置 + if config is None: + # 使用默认的剪枝配置 + config = PruningConfig( + pruning_switch=False, # 默认关闭剪枝,保持向后兼容 + pruning_scene="education", + pruning_threshold=0.5 + ) + + self.config = config self.llm_client = llm_client + self.language = language # 保存语言配置 + self.max_concurrent = max_concurrent # 新增:最大并发数 + # Load Jinja2 template self.template = prompt_env.get_template("extracat_Pruning.jinja2") - # 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染 - self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {} + + # 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存 + self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict() + self._cache_max_size = 1000 # 缓存大小限制 + # 运行日志:收集关键终端输出,便于写入 JSON self.run_logs: List[str] = [] - # 采用顺序处理,移除并发配置以简化与稳定执行 + + # 扩展的填充词库(包含表情符号和网络用语) + self._extended_fillers = [ + # 基础寒暄 + "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", + "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", + "拜拜", "再见", "88", "拜", "回见", + # 口头禅 + "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", + "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", + # 确认词 + "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", + # 标点和符号 + "。。。", "...", "???", "???", "!!!", "!!!", + # 表情符号(文本形式) + "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", + "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", + "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", + "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", + # 网络用语 + "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", + "emmm", "emm", "em", "mmp", "wtf", "omg", + ] def _is_important_message(self, message: ConversationMessage) -> bool: """基于启发式规则识别重要信息消息,优先保留。 - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。 - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。 + 改进版:增强了规则覆盖范围,包括: + - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午) + - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段 + - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址" + - 新增:问句识别、决策性语句、承诺性语句 """ - import re text = message.msg.strip() if not text: return False + patterns = [ - r"\b\d{4}-\d{1,2}-\d{1,2}\b", - r"\b\d{1,2}:\d{2}\b", + # 原有模式 + r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界 + r"\d{1,2}:\d{2}", # 修复:移除 \b r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM", - r"订单号|工单|申请号|编号|ID|账号|账户", - r"电话|手机号|微信|QQ|邮箱", - r"地址|地点", - r"金额|费用|价格|¥|¥|\d+元", - r"时间|日期|有效期|截止", + r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月", + r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号", + r"电话|手机号|微信|QQ|邮箱|联系方式", + r"地址|地点|位置|门牌号", + r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元", + r"时间|日期|有效期|截止|期限|到期", + # 新增模式 + r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词 + r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句 + r"承诺|保证|确保|负责|同意|答应", # 承诺性语句 + r"\d{11}", # 11位手机号 + r"\d{3,4}-\d{7,8}", # 固定电话 + r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱 ] + for p in patterns: if re.search(p, text, flags=re.IGNORECASE): return True + + # 检查是否为问句(以问号结尾或包含疑问词) + if text.endswith("?") or text.endswith("?"): + return True + return False + def _importance_score(self, message: ConversationMessage) -> int: """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - 简单启发:匹配到的类别越多、越关键分值越高。 + 改进版:更细致的评分体系(0-10分) """ - import re text = message.msg.strip() score = 0 + weights = [ - (r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3), - (r"\b\d{1,2}:\d{2}\b", 2), + # 高优先级(4-5分) + (r"订单号|工单|申请号|编号|ID|账号|账户", 5), + (r"金额|费用|价格|¥|¥|\d+元", 5), + (r"\d{11}", 4), # 手机号 + (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 + + # 中优先级(2-3分) + (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"订单号|工单|申请号|编号|ID|账号|账户", 4), - (r"电话|手机号|微信|QQ|邮箱", 3), - (r"地址|地点", 2), - (r"金额|费用|价格|¥|¥|\d+元", 4), - (r"时间|日期|有效期|截止", 2), + (r"电话|手机号|微信|QQ|联系方式", 3), + (r"地址|地点|位置", 2), + (r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词 + + # 低优先级(1分) + (r"\d{1,2}:\d{2}", 1), # 修复:移除 \b + (r"上午|下午|AM|PM", 1), ] + for p, w in weights: if re.search(p, text, flags=re.IGNORECASE): score += w - return score + + # 问句加分 + if text.endswith("?") or text.endswith("?"): + score += 2 + + # 长度加分(较长的消息通常包含更多信息) + if len(text) > 50: + score += 1 + if len(text) > 100: + score += 1 + + return min(score, 10) # 最高10分 def _is_filler_message(self, message: ConversationMessage) -> bool: """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 + 改进版:扩展了填充词库,支持表情符号和网络用语 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体; - - 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。 + - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体 + - 在扩展填充词库中 + - 纯表情符号 """ - import re t = message.msg.strip() if not t: return True - # 常见填充语 - fillers = [ - "你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢", - "拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??" - ] - if t in fillers: + + # 检查是否在扩展填充词库中 + if t in self._extended_fillers: return True + + # 检查是否为纯表情符号(方括号包裹) + if re.fullmatch(r"(\[[^\]]+\])+", t): + return True + + # 检查是否为纯emoji(Unicode表情) + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # 表情符号 + "\U0001F300-\U0001F5FF" # 符号和象形文字 + "\U0001F680-\U0001F6FF" # 交通和地图符号 + "\U0001F1E0-\U0001F1FF" # 旗帜 + "\U00002702-\U000027B0" + "\U000024C2-\U0001F251" + "]+", flags=re.UNICODE + ) + if emoji_pattern.fullmatch(t): + return True + # 长度与字符类型判断 if len(t) <= 8: # 非数字、无关键实体的短文本 if not re.search(r"[0-9]", t) and not self._is_important_message(message): # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers: + if re.fullmatch(r"[。!?,.!?…·\s]+", t): return True + return False + + async def _batch_evaluate_importance_with_llm( + self, + messages: List[ConversationMessage], + context: str = "" + ) -> Dict[int, int]: + """使用LLM批量评估消息的重要性(语义层面)。 + + Args: + messages: 消息列表 + context: 对话上下文(可选) + + Returns: + 消息索引到重要性分数(0-10)的映射 + """ + if not self.llm_client or not messages: + return {} + + # 构建批量评估的提示词 + msg_list = [] + for idx, msg in enumerate(messages): + msg_list.append(f"{idx}. {msg.msg}") + + msg_text = "\n".join(msg_list) + + prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分): +- 0-2分:无意义的寒暄、口头禅、纯表情 +- 3-5分:一般性对话,有一定信息量但不关键 +- 6-8分:包含重要信息(时间、地点、人物、事件等) +- 9-10分:关键决策、承诺、重要数据 + +对话上下文: +{context if context else "无"} + +待评估的消息: +{msg_text} + +请以JSON格式返回,格式为: +{{ + "importance_scores": {{ + "0": 分数, + "1": 分数, + ... + }} +}} +""" + + try: + messages_for_llm = [ + {"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"}, + {"role": "user", "content": prompt} + ] + + response = await self.llm_client.response_structured( + messages_for_llm, + MessageImportanceResponse + ) + + # 转换字符串键为整数键 + return {int(k): v for k, v in response.importance_scores.items()} + except Exception as e: + self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}") + return {} + + def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: + """识别对话中的问答对,用于保护问答结构的完整性。 + + Args: + messages: 消息列表 + + Returns: + 问答对列表 + """ + qa_pairs = [] + + for i in range(len(messages) - 1): + current_msg = messages[i].msg.strip() + next_msg = messages[i + 1].msg.strip() + + # 简单规则:如果当前消息是问句,下一条消息可能是答案 + is_question = ( + current_msg.endswith("?") or + current_msg.endswith("?") or + any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"]) + ) + + if is_question and next_msg: + # 检查下一条消息是否像答案(不是另一个问句) + is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) + + if is_answer: + qa_pairs.append(QAPair( + question_idx=i, + answer_idx=i + 1, + confidence=0.8 # 基于规则的置信度 + )) + + return qa_pairs + + def _get_protected_indices( + self, + messages: List[ConversationMessage], + qa_pairs: List[QAPair], + window_size: int = 2 + ) -> Set[int]: + """获取需要保护的消息索引集合(问答对+上下文窗口)。 + + Args: + messages: 消息列表 + qa_pairs: 问答对列表 + window_size: 上下文窗口大小(前后各保留几条消息) + + Returns: + 需要保护的消息索引集合 + """ + protected = set() + + for qa_pair in qa_pairs: + # 保护问答对本身 + protected.add(qa_pair.question_idx) + protected.add(qa_pair.answer_idx) + + # 保护上下文窗口 + for offset in range(-window_size, window_size + 1): + q_idx = qa_pair.question_idx + offset + a_idx = qa_pair.answer_idx + offset + + if 0 <= q_idx < len(messages): + protected.add(q_idx) + if 0 <= a_idx < len(messages): + protected.add(a_idx) + + return protected async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse: """对话级一次性抽取:从整段对话中提取重要信息并判定相关性。 - - 仅使用 LLM 结构化输出; + 改进版: + - LRU缓存管理 + - 重试机制 + - 降级策略 """ # 缓存命中则直接返回(场景+内容作为键) cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() + + # LRU缓存:如果命中,移到末尾(最近使用) if cache_key in self._dialog_extract_cache: + self._dialog_extract_cache.move_to_end(cache_key) return self._dialog_extract_cache[cache_key] - rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text) - log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene}) + # LRU缓存大小限制:超过限制时删除最旧的条目 + if len(self._dialog_extract_cache) >= self._cache_max_size: + # 删除最旧的条目(OrderedDict的第一个) + oldest_key = next(iter(self._dialog_extract_cache)) + del self._dialog_extract_cache[oldest_key] + self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目") + + rendered = self.template.render( + pruning_scene=self.config.pruning_scene, + dialog_text=dialog_text, + language=self.language + ) + log_template_rendering("extracat_Pruning.jinja2", { + "pruning_scene": self.config.pruning_scene, + "language": self.language + }) log_prompt_rendering("pruning-extract", rendered) - # 强制使用 LLM;移除正则回退 + # 强制使用 LLM if not self.llm_client: raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。") @@ -153,12 +417,32 @@ class SemanticPruner: {"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"}, {"role": "user", "content": rendered}, ] - try: - ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) - self._dialog_extract_cache[cache_key] = ex - return ex - except Exception as e: - raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e + + # 重试机制 + max_retries = 3 + for attempt in range(max_retries): + try: + ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) + self._dialog_extract_cache[cache_key] = ex + return ex + except Exception as e: + if attempt < max_retries - 1: + self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}") + await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避 + continue + else: + # 降级策略:标记为相关,避免误删 + self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)") + fallback_response = DialogExtractionResponse( + is_related=True, + times=[], + ids=[], + amounts=[], + contacts=[], + addresses=[], + keywords=[] + ) + return fallback_response def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: """判断消息是否包含任意抽取到的重要片段。""" @@ -248,12 +532,15 @@ class SemanticPruner: async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]: """数据集层面:全局消息级剪枝,保留所有对话。 - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。 - - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。 - - 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。 - - 保证每段对话至少保留1条消息,不会删除整段对话。 + 改进版: + - 并发处理对话级相关性判断 + - 问答对识别和保护 + - 优化删除策略,保持上下文连贯性 + - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动 + - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 + - 保证每段对话至少保留1条消息,不会删除整段对话 """ - # 如果剪枝功能关闭,直接返回原始数据集。 + # 如果剪枝功能关闭,直接返回原始数据集 if not self.config.pruning_switch: return dialogs @@ -264,29 +551,36 @@ class SemanticPruner: proportion = 0.9 if proportion < 0.0: proportion = 0.0 - evaluated_dialogs = [] # list of dicts: {dialog, is_related} self._log( f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" ) - # 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存) - evaluated_dialogs = [] - for idx, dd in enumerate(dialogs): - try: - ex = await self._extract_dialog_important(dd.content) - evaluated_dialogs.append({ - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - }) - except Exception: - evaluated_dialogs.append({ - "dialog": dd, - "is_related": True, - "index": idx, - "extraction": None - }) + + # 并发处理对话级相关性分类 + semaphore = asyncio.Semaphore(self.max_concurrent) + + async def classify_dialog(idx: int, dd: DialogData): + async with semaphore: + try: + ex = await self._extract_dialog_important(dd.content) + return { + "dialog": dd, + "is_related": bool(ex.is_related), + "index": idx, + "extraction": ex + } + except Exception as e: + self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}") + return { + "dialog": dd, + "is_related": True, # 失败时标记为相关,避免误删 + "index": idx, + "extraction": None + } + + # 并发执行所有对话的分类 + tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)] + evaluated_dialogs = await asyncio.gather(*tasks) # 统计相关 / 不相关对话 not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] @@ -300,7 +594,6 @@ class SemanticPruner: inds = [i["index"] + 1 for i in items] if len(inds) <= cap: return inds - # 超过上限时只打印前cap个,并标注总数 return inds[:cap] + ["...", f"共{len(inds)}个"] rel_inds = _fmt_indices(related_dialogs) @@ -309,59 +602,83 @@ class SemanticPruner: result: List[DialogData] = [] if not_related_dialogs: - # 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM) + # 为每个不相关对话进行分析 per_dialog_info = {} total_unrelated = 0 - total_capacity = 0 + for d in not_related_dialogs: dd = d["dialog"] extraction = d.get("extraction") if extraction is None: extraction = await self._extract_dialog_important(dd.content) + # 合并所有重要标记 tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords msgs = dd.context.msgs - # 分类消息 - imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)] - unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs] + + # 识别问答对 + qa_pairs = self._identify_qa_pairs(msgs) + protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1) + + # 分类消息(考虑问答对保护) + imp_unrel_msgs = [] + unimp_unrel_msgs = [] + + for idx, m in enumerate(msgs): + # 问答对中的消息自动标记为重要 + if idx in protected_indices: + imp_unrel_msgs.append((idx, m)) + elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m): + imp_unrel_msgs.append((idx, m)) + elif not self._is_filler_message(m): + unimp_unrel_msgs.append((idx, m)) + # 填充消息不加入任何列表,优先删除 + # 重要消息按重要性排序 - imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))] + imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1])) + imp_sorted_ids = [id(m) for _, m in imp_sorted] + info = { "dialog": dd, "total_msgs": len(msgs), "unrelated_count": len(msgs), "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for m in unimp_unrel_msgs], + "unimp_ids": [id(m) for _, m in unimp_unrel_msgs], + "protected_indices": protected_indices, + "qa_pairs_count": len(qa_pairs), } per_dialog_info[d["index"]] = info total_unrelated += info["unrelated_count"] - # 全局删除配额:比例作用于全部不相关消息(重要+不重要) + + # 全局删除配额计算 global_delete = int(total_unrelated * proportion) if proportion > 0 and total_unrelated > 0 and global_delete == 0: global_delete = 1 - # 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息 + + # 每段的最大可删容量 capacities = [] for d in not_related_dialogs: idx = d["index"] info = per_dialog_info[idx] - # 统计重要数量 imp_count = len(info["imp_ids_sorted"]) unimp_count = len(info["unimp_ids"]) imp_cap = int(imp_count * proportion) cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) capacities.append(cap) + total_capacity = sum(capacities) if global_delete > total_capacity: - print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。") + self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。") global_delete = total_capacity - # 配额分配:按不相关消息占比分配到各对话,但不超过各自容量 + # 配额分配 alloc = [] for i, d in enumerate(not_related_dialogs): idx = d["index"] info = per_dialog_info[idx] share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 alloc.append(min(share, capacities[i])) + allocated = sum(alloc) rem = global_delete - allocated turn = 0 @@ -378,34 +695,40 @@ class SemanticPruner: break turn += 1 - # 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先) + # 应用删除 total_deleted_confirm = 0 for d in evaluated_dialogs: dd = d["dialog"] msgs = dd.context.msgs original = len(msgs) + if d["is_related"]: result.append(dd) continue + idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) if idx_in_unrel is None: result.append(dd) continue + quota = alloc[idx_in_unrel] info = per_dialog_info[d["index"]] - # 计算本对话重要最多可删数量 + + # 计算删除ID imp_count = len(info["imp_ids_sorted"]) imp_del_cap = int(imp_count * proportion) - # 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条) + unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) del_unimp = min(quota, len(unimp_delete_ids)) rem_quota = quota - del_unimp - # 再从重要里选低分优先的删除ID(不超过 imp_del_cap) + imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) + deleted_here = 0 actual_unimp_deleted = 0 actual_imp_deleted = 0 kept = [] + for m in msgs: mid = id(m) if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: @@ -417,26 +740,30 @@ class SemanticPruner: deleted_here += 1 continue kept.append(m) + if not kept and msgs: kept = [msgs[0]] + dd.context.msgs = kept total_deleted_confirm += deleted_here + + qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else "" self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}" + f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}" ) result.append(dd) - self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。") + + self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。") else: - # 全部相关:不执行剪枝 result = [d["dialog"] for d in evaluated_dialogs] + self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") - # 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成) + # 保存日志 try: from app.core.config import settings settings.ensure_memory_output_dir() log_output_path = settings.get_memory_output_path("pruned_terminal.json") - # 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存 sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs] payload = self._parse_logs_to_structured(sanitized_logs) with open(log_output_path, "w", encoding="utf-8") as f: @@ -448,6 +775,7 @@ class SemanticPruner: if not result: print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs + return result def _log(self, msg: str) -> None: From 0655ff4a9133322592bdd69cb3feb807d14f1765 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 27 Feb 2026 16:45:34 +0800 Subject: [PATCH 028/164] [fix]Correct the flaws existing in the semantic segmentation method --- .../knowledge_extraction/chunk_extraction.py | 205 ++++++++++++++---- 1 file changed, 160 insertions(+), 45 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index 40e98507..bbbf1c51 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -1,5 +1,7 @@ import os -from typing import Optional +from typing import Optional, List, Any +from enum import Enum +from pathlib import Path from app.core.logging_config import get_memory_logger from app.core.memory.models.message_models import DialogData, Chunk @@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config logger = get_memory_logger(__name__) +class ChunkerStrategy(Enum): + """Supported chunking strategies.""" + RECURSIVE = "RecursiveChunker" + SEMANTIC = "SemanticChunker" + LATE = "LateChunker" + NEURAL = "NeuralChunker" + LLM = "LLMChunker" + + @classmethod + def get_valid_strategies(cls) -> List[str]: + """Get list of valid strategy names.""" + return [strategy.value for strategy in cls] + + class DialogueChunker: """A class that processes dialogues and fills them with chunks based on a specified strategy. @@ -17,23 +33,51 @@ class DialogueChunker: of different chunking strategies to dialogue data. """ - def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None): + def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None): """Initialize the DialogueChunker with a specific chunking strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker + llm_client: LLM client instance (required for LLMChunker strategy) + + Raises: + ValueError: If chunker_strategy is invalid or required parameters are missing """ - self.chunker_strategy = chunker_strategy - chunker_config_dict = get_chunker_config(chunker_strategy) - self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + # Validate strategy + valid_strategies = ChunkerStrategy.get_valid_strategies() + if chunker_strategy not in valid_strategies: + raise ValueError( + f"Invalid chunker_strategy: '{chunker_strategy}'. " + f"Must be one of {valid_strategies}" + ) - if self.chunker_config.chunker_strategy == "LLMChunker": - self.chunker_client = ChunkerClient(self.chunker_config, llm_client) - else: - self.chunker_client = ChunkerClient(self.chunker_config) + self.chunker_strategy = chunker_strategy + logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}") + + try: + # Load and validate configuration + chunker_config_dict = get_chunker_config(chunker_strategy) + if not chunker_config_dict: + raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}") + + self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + + # Initialize chunker client + if self.chunker_config.chunker_strategy == "LLMChunker": + if not llm_client: + raise ValueError("llm_client is required for LLMChunker strategy") + self.chunker_client = ChunkerClient(self.chunker_config, llm_client) + else: + self.chunker_client = ChunkerClient(self.chunker_config) + + logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}") + + except Exception as e: + logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True) + raise - async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]: + async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]: """Process a dialogue by generating chunks and adding them to the DialogData object. Args: @@ -43,54 +87,125 @@ class DialogueChunker: A list of Chunk objects Raises: - ValueError: If chunking fails or returns empty chunks + ValueError: If dialogue is invalid or chunking fails + Exception: If chunking process encounters an error """ - result_dialogue = await self.chunker_client.generate_chunks(dialogue) - chunks = result_dialogue.chunks - - if not chunks or len(chunks) == 0: + # Validate input + if not dialogue: + raise ValueError("dialogue cannot be None") + + if not dialogue.context or not dialogue.context.msgs: raise ValueError( - f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " - f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " - f"Strategy: {self.chunker_config.chunker_strategy}" + f"Dialogue {dialogue.ref_id} has no messages to chunk. " + f"Context: {dialogue.context is not None}, " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}" ) + + logger.info( + f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages " + f"using strategy: {self.chunker_strategy}" + ) + + try: + # Generate chunks + result_dialogue = await self.chunker_client.generate_chunks(dialogue) + chunks = result_dialogue.chunks - return chunks + # Validate results + if not chunks or len(chunks) == 0: + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs)}, " + f"Content length: {len(dialogue.content) if dialogue.content else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) - def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str: + logger.info( + f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. " + f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}" + ) + + return chunks + + except ValueError: + # Re-raise validation errors + raise + except Exception as e: + logger.error( + f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}", + exc_info=True + ) + raise + + def save_chunking_results( + self, + chunks: List[Chunk], + dialogue: DialogData, + output_path: Optional[str] = None, + preview_length: int = 100 + ) -> str: """Save the chunking results to a file and return the output path. Args: - dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output + chunks: List of Chunk objects to save + dialogue: The DialogData object that was processed + output_path: Optional path to save the output (defaults to current directory) + preview_length: Maximum length of content preview (default: 100) Returns: The path where the output was saved + + Raises: + ValueError: If chunks or dialogue is invalid + IOError: If file writing fails """ - if not output_path: - output_path = os.path.join( - os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt" - ) - - output_lines = [ - f"=== Chunking Results ({self.chunker_strategy}) ===", - f"Dialogue ID: {dialogue.ref_id}", - f"Original conversation has {len(dialogue.context.msgs)} messages", - f"Total characters: {len(dialogue.content)}", - f"Generated {len(dialogue.chunks)} chunks:" - ] + # Validate input + if not chunks: + raise ValueError("chunks list cannot be empty") + if not dialogue: + raise ValueError("dialogue cannot be None") - for i, chunk in enumerate(dialogue.chunks): - output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") - output_lines.append(f" Content preview: {chunk.content}...") - if chunk.metadata: - output_lines.append(f" Metadata: {chunk.metadata}") + # Generate default output path if not provided + if not output_path: + output_dir = Path(__file__).parent.parent.parent + output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt") + + logger.info(f"Saving chunking results to: {output_path}") + + try: + # Prepare output content + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages", + f"Total characters: {len(dialogue.content) if dialogue.content else 0}", + f"Generated {len(chunks)} chunks:", + "" + ] + + for i, chunk in enumerate(chunks, 1): + content_preview = chunk.content[:preview_length] if chunk.content else "" + if len(chunk.content) > preview_length: + content_preview += "..." + + output_lines.append(f" Chunk {i}: {len(chunk.content)} characters") + output_lines.append(f" Content preview: {content_preview}") + if chunk.metadata: + output_lines.append(f" Metadata: {chunk.metadata}") + output_lines.append("") - with open(output_path, "w", encoding="utf-8") as f: - f.write("\n".join(output_lines)) + # Write to file + with open(output_path, "w", encoding="utf-8") as f: + f.write("\n".join(output_lines)) - logger.info(f"Chunking results saved to: {output_path}") - return output_path + logger.info(f"Successfully saved chunking results to: {output_path}") + return output_path + + except IOError as e: + logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True) + raise + except Exception as e: + logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True) + raise From 96590941cf81e1de21c54d411c890cb1d8195f51 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 17:18:42 +0800 Subject: [PATCH 029/164] [add]The semantic pruning function is activated, removing the protection of question-answer pairs. --- .../core/memory/agent/utils/get_dialogs.py | 57 +- .../data_preprocessing/data_pruning.py | 511 ++++++++---------- .../data_preprocessing/scene_config.py | 326 +++++++++++ .../extraction_orchestrator.py | 16 +- 4 files changed, 619 insertions(+), 291 deletions(-) create mode 100644 api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index bfb0f675..22555fff 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -21,7 +21,7 @@ async def get_chunked_dialogs( end_user_id: Group identifier messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier - config_id: Configuration ID for processing + config_id: Configuration ID for processing (used to load pruning config) Returns: List of DialogData objects with generated chunks @@ -57,6 +57,61 @@ async def get_chunked_dialogs( end_user_id=end_user_id, config_id=config_id ) + + # 语义剪枝步骤(在分块之前) + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner + from app.core.memory.models.config_models import PruningConfig + from app.db import get_db_context + from app.services.memory_config_service import MemoryConfigService + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # 加载剪枝配置 + pruning_config = None + if config_id: + try: + with get_db_context() as db: + # 使用 MemoryConfigService 加载完整的 MemoryConfig 对象 + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="semantic_pruning" + ) + + if memory_config: + pruning_config = PruningConfig( + pruning_switch=memory_config.pruning_enabled, + pruning_scene=memory_config.pruning_scene or "education", + pruning_threshold=memory_config.pruning_threshold + ) + logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") + + # 获取LLM客户端用于剪枝 + if pruning_config.pruning_switch: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) + + # 执行剪枝 - 使用 prune_dataset 支持消息级剪枝 + pruner = SemanticPruner(config=pruning_config, llm_client=llm_client) + original_msg_count = len(dialog_data.context.msgs) + + # 使用 prune_dataset 而不是 prune_dialog + # prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息 + pruned_dialogs = await pruner.prune_dataset([dialog_data]) + + if pruned_dialogs: + dialog_data = pruned_dialogs[0] + remaining_msg_count = len(dialog_data.context.msgs) + deleted_count = original_msg_count - remaining_msg_count + logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)") + else: + logger.warning("[剪枝] prune_dataset 返回空列表") + else: + logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝") + except Exception as e: + logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True) + except Exception as e: + logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True) chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 2d0142c6..d932c542 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -22,6 +22,10 @@ from app.core.memory.models.message_models import DialogData, ConversationMessag from app.core.memory.models.config_models import PruningConfig from app.core.memory.utils.config.config_utils import get_pruning_config from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering +from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( + SceneConfigRegistry, + ScenePatterns +) class DialogExtractionResponse(BaseModel): @@ -78,6 +82,20 @@ class SemanticPruner: self.language = language # 保存语言配置 self.max_concurrent = max_concurrent # 新增:最大并发数 + # 加载场景特定配置 + self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( + self.config.pruning_scene, + fallback_to_generic=True + ) + + # 检查场景是否有专门支持 + is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) + if is_supported: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置") + else: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)") + self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}") + # Load Jinja2 template self.template = prompt_env.get_template("extracat_Pruning.jinja2") @@ -87,108 +105,80 @@ class SemanticPruner: # 运行日志:收集关键终端输出,便于写入 JSON self.run_logs: List[str] = [] - - # 扩展的填充词库(包含表情符号和网络用语) - self._extended_fillers = [ - # 基础寒暄 - "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", - "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", - "拜拜", "再见", "88", "拜", "回见", - # 口头禅 - "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", - "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", - # 确认词 - "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", - # 标点和符号 - "。。。", "...", "???", "???", "!!!", "!!!", - # 表情符号(文本形式) - "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", - "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", - "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", - "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", - # 网络用语 - "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", - "emmm", "emm", "em", "mmp", "wtf", "omg", - ] def _is_important_message(self, message: ConversationMessage) -> bool: """基于启发式规则识别重要信息消息,优先保留。 - 改进版:增强了规则覆盖范围,包括: - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午) - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址" - - 新增:问句识别、决策性语句、承诺性语句 + 改进版:使用场景特定的模式进行识别 + - 根据 pruning_scene 动态加载对应的识别规则 + - 支持教育、在线服务、外呼三个场景的特定模式 """ text = message.msg.strip() if not text: return False - patterns = [ - # 原有模式 - r"\d{4}-\d{1,2}-\d{1,2}", # 修复:移除 \b 边界,因为中文前后没有单词边界 - r"\d{1,2}:\d{2}", # 修复:移除 \b - r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM|今天|明天|后天|昨天|前天|本周|下周|上周|本月|下月|上月", - r"订单号|工单|申请号|编号|ID|账号|账户|流水号|单号", - r"电话|手机号|微信|QQ|邮箱|联系方式", - r"地址|地点|位置|门牌号", - r"金额|费用|价格|¥|¥|\d+元|人民币|美元|欧元", - r"时间|日期|有效期|截止|期限|到期", - # 新增模式 - r"什么|为什么|怎么|如何|哪里|哪个|谁|多少|几点|何时", # 问句关键词 - r"必须|一定|务必|需要|要求|规定|应该", # 决策性语句 - r"承诺|保证|确保|负责|同意|答应", # 承诺性语句 - r"\d{11}", # 11位手机号 - r"\d{3,4}-\d{7,8}", # 固定电话 - r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", # 邮箱 - ] + # 使用场景特定的模式 + all_patterns = ( + self.scene_config.high_priority_patterns + + self.scene_config.medium_priority_patterns + + self.scene_config.low_priority_patterns + ) - for p in patterns: - if re.search(p, text, flags=re.IGNORECASE): + for pattern, _ in all_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): return True # 检查是否为问句(以问号结尾或包含疑问词) if text.endswith("?") or text.endswith("?"): return True + + # 检查是否包含问句关键词 + if any(keyword in text for keyword in self.scene_config.question_keywords): + return True + + # 检查是否包含决策性关键词 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + return True return False def _importance_score(self, message: ConversationMessage) -> int: """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - 改进版:更细致的评分体系(0-10分) + 改进版:使用场景特定的权重体系(0-10分) + - 根据场景动态调整不同信息类型的权重 + - 高优先级模式:4-6分 + - 中优先级模式:2-3分 + - 低优先级模式:1分 """ text = message.msg.strip() score = 0 - weights = [ - # 高优先级(4-5分) - (r"订单号|工单|申请号|编号|ID|账号|账户", 5), - (r"金额|费用|价格|¥|¥|\d+元", 5), - (r"\d{11}", 4), # 手机号 - (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 - - # 中优先级(2-3分) - (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 修复:移除 \b - (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"电话|手机号|微信|QQ|联系方式", 3), - (r"地址|地点|位置", 2), - (r"时间|日期|有效期|截止|明天|后天|下周|下月", 2), # 新增时间相关词 - - # 低优先级(1分) - (r"\d{1,2}:\d{2}", 1), # 修复:移除 \b - (r"上午|下午|AM|PM", 1), - ] + # 使用场景特定的权重 + for pattern, weight in self.scene_config.high_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight - for p, w in weights: - if re.search(p, text, flags=re.IGNORECASE): - score += w + for pattern, weight in self.scene_config.medium_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight + + for pattern, weight in self.scene_config.low_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight # 问句加分 if text.endswith("?") or text.endswith("?"): score += 2 + # 包含问句关键词加分 + if any(keyword in text for keyword in self.scene_config.question_keywords): + score += 1 + + # 包含决策性关键词加分 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + score += 2 + # 长度加分(较长的消息通常包含更多信息) if len(text) > 50: score += 1 @@ -198,20 +188,35 @@ class SemanticPruner: return min(score, 10) # 最高10分 def _is_filler_message(self, message: ConversationMessage) -> bool: - """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 + """检测典型寒暄/口头禅/确认类短消息。 - 改进版:扩展了填充词库,支持表情符号和网络用语 + 改进版:更严格的填充消息判断,避免误删场景相关内容 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体 - - 在扩展填充词库中 + - 纯标点或空白 + - 在场景特定填充词库中(精确匹配) - 纯表情符号 + - 常见寒暄(精确匹配短语) + + 注意:不再使用长度判断,避免误删短但重要的消息 """ t = message.msg.strip() if not t: return True - # 检查是否在扩展填充词库中 - if t in self._extended_fillers: + # 检查是否在场景特定填充词库中(精确匹配) + if t in self.scene_config.filler_phrases: + return True + + # 常见寒暄和问候(精确匹配,避免误删) + common_greetings = { + "在吗", "在不在", "在呢", "在的", + "你好", "您好", "hello", "hi", + "拜拜", "再见", "拜", "88", "bye", + "好的", "好", "行", "可以", "嗯", "哦", "啊", + "是的", "对", "对的", "没错", "是啊", + "哈哈", "呵呵", "嘿嘿", "嗯嗯" + } + if t in common_greetings: return True # 检查是否为纯表情符号(方括号包裹) @@ -232,13 +237,9 @@ class SemanticPruner: if emoji_pattern.fullmatch(t): return True - # 长度与字符类型判断 - if len(t) <= 8: - # 非数字、无关键实体的短文本 - if not re.search(r"[0-9]", t) and not self._is_important_message(message): - # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t): - return True + # 纯标点符号 + if re.fullmatch(r"[。!?,.!?…·\s]+", t): + return True return False @@ -308,6 +309,8 @@ class SemanticPruner: def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: """识别对话中的问答对,用于保护问答结构的完整性。 + 改进版:使用场景特定的问句关键词,并排除寒暄类问句 + Args: messages: 消息列表 @@ -316,21 +319,39 @@ class SemanticPruner: """ qa_pairs = [] + # 寒暄类问句,不应该被保护(这些不是真正的问答) + greeting_questions = { + "在吗", "在不在", "你好吗", "怎么样", "好吗", + "有空吗", "忙吗", "睡了吗", "起床了吗" + } + for i in range(len(messages) - 1): current_msg = messages[i].msg.strip() next_msg = messages[i + 1].msg.strip() - # 简单规则:如果当前消息是问句,下一条消息可能是答案 - is_question = ( - current_msg.endswith("?") or - current_msg.endswith("?") or - any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"]) - ) + # 排除寒暄类问句 + if current_msg in greeting_questions: + continue + + # 使用场景特定的问句关键词,但要求更严格 + is_question = False + + # 1. 以问号结尾 + if current_msg.endswith("?") or current_msg.endswith("?"): + is_question = True + # 2. 包含实质性问句关键词(排除"吗"这种太宽泛的) + elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]): + is_question = True if is_question and next_msg: - # 检查下一条消息是否像答案(不是另一个问句) + # 检查下一条消息是否像答案(不是另一个问句,也不是寒暄) is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) + # 排除寒暄类回复 + greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"} + if next_msg in greeting_answers: + is_answer = False + if is_answer: qa_pairs.append(QAPair( question_idx=i, @@ -533,10 +554,9 @@ class SemanticPruner: """数据集层面:全局消息级剪枝,保留所有对话。 改进版: - - 并发处理对话级相关性判断 - - 问答对识别和保护 - - 优化删除策略,保持上下文连贯性 - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动 + - 消息级独立判断,每条消息根据场景规则独立评估 + - 问答对保护已注释(暂不启用,留作观察) + - 优化删除策略:填充消息 → 不重要消息 → 低分重要消息 - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 - 保证每段对话至少保留1条消息,不会删除整段对话 """ @@ -553,209 +573,122 @@ class SemanticPruner: proportion = 0.0 self._log( - f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" + f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" ) - # 并发处理对话级相关性分类 - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def classify_dialog(idx: int, dd: DialogData): - async with semaphore: - try: - ex = await self._extract_dialog_important(dd.content) - return { - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - } - except Exception as e: - self._log(f"[剪枝-并发] 对话 {idx} 分类失败: {str(e)[:100]}") - return { - "dialog": dd, - "is_related": True, # 失败时标记为相关,避免误删 - "index": idx, - "extraction": None - } - - # 并发执行所有对话的分类 - tasks = [classify_dialog(idx, dd) for idx, dd in enumerate(dialogs)] - evaluated_dialogs = await asyncio.gather(*tasks) - - # 统计相关 / 不相关对话 - not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] - related_dialogs = [d for d in evaluated_dialogs if d["is_related"]] - self._log( - f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}" - ) - - # 简洁打印第几段对话相关/不相关(索引基于1) - def _fmt_indices(items, cap: int = 10): - inds = [i["index"] + 1 for i in items] - if len(inds) <= cap: - return inds - return inds[:cap] + ["...", f"共{len(inds)}个"] - - rel_inds = _fmt_indices(related_dialogs) - nrel_inds = _fmt_indices(not_related_dialogs) - self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段") - result: List[DialogData] = [] - if not_related_dialogs: - # 为每个不相关对话进行分析 - per_dialog_info = {} - total_unrelated = 0 + total_original_msgs = 0 + total_deleted_msgs = 0 + + for d_idx, dd in enumerate(dialogs): + msgs = dd.context.msgs + original_count = len(msgs) + total_original_msgs += original_count - for d in not_related_dialogs: - dd = d["dialog"] - extraction = d.get("extraction") - if extraction is None: - extraction = await self._extract_dialog_important(dd.content) - - # 合并所有重要标记 - tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords - msgs = dd.context.msgs - - # 识别问答对 - qa_pairs = self._identify_qa_pairs(msgs) - protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=1) - - # 分类消息(考虑问答对保护) - imp_unrel_msgs = [] - unimp_unrel_msgs = [] - - for idx, m in enumerate(msgs): - # 问答对中的消息自动标记为重要 - if idx in protected_indices: - imp_unrel_msgs.append((idx, m)) - elif self._msg_matches_tokens(m, tokens) or self._is_important_message(m): - imp_unrel_msgs.append((idx, m)) - elif not self._is_filler_message(m): - unimp_unrel_msgs.append((idx, m)) - # 填充消息不加入任何列表,优先删除 - - # 重要消息按重要性排序 - imp_sorted = sorted(imp_unrel_msgs, key=lambda x: self._importance_score(x[1])) - imp_sorted_ids = [id(m) for _, m in imp_sorted] - - info = { - "dialog": dd, - "total_msgs": len(msgs), - "unrelated_count": len(msgs), - "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for _, m in unimp_unrel_msgs], - "protected_indices": protected_indices, - "qa_pairs_count": len(qa_pairs), - } - per_dialog_info[d["index"]] = info - total_unrelated += info["unrelated_count"] + # ========== 问答对保护(已注释,暂不启用,留作观察) ========== + # qa_pairs = self._identify_qa_pairs(msgs) + # protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0) + # ======================================================== - # 全局删除配额计算 - global_delete = int(total_unrelated * proportion) - if proportion > 0 and total_unrelated > 0 and global_delete == 0: - global_delete = 1 + # 消息级分类:每条消息独立判断 + important_msgs = [] # 重要消息(保留) + unimportant_msgs = [] # 不重要消息(可删除) + filler_msgs = [] # 填充消息(优先删除) - # 每段的最大可删容量 - capacities = [] - for d in not_related_dialogs: - idx = d["index"] - info = per_dialog_info[idx] - imp_count = len(info["imp_ids_sorted"]) - unimp_count = len(info["unimp_ids"]) - imp_cap = int(imp_count * proportion) - cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) - capacities.append(cap) + for idx, m in enumerate(msgs): + msg_text = m.msg.strip() + + # ========== 问答对保护判断(已注释) ========== + # if idx in protected_indices: + # important_msgs.append((idx, m)) + # self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)") + # ========================================== + + # 填充消息(寒暄、表情等) + if self._is_filler_message(m): + filler_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") + # 重要信息(学号、成绩、时间、金额等) + elif self._is_important_message(m): + important_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)") + # 其他消息 + else: + unimportant_msgs.append((idx, m)) + self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要") - total_capacity = sum(capacities) - if global_delete > total_capacity: - self._log(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}。将按最大可删执行。") - global_delete = total_capacity - - # 配额分配 - alloc = [] - for i, d in enumerate(not_related_dialogs): - idx = d["index"] - info = per_dialog_info[idx] - share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 - alloc.append(min(share, capacities[i])) + # 计算删除配额 + delete_target = int(original_count * proportion) + if proportion > 0 and original_count > 0 and delete_target == 0: + delete_target = 1 - allocated = sum(alloc) - rem = global_delete - allocated - turn = 0 - while rem > 0 and turn < 100000: - progressed = False - for i in range(len(not_related_dialogs)): - if rem <= 0: - break - if alloc[i] < capacities[i]: - alloc[i] += 1 - rem -= 1 - progressed = True - if not progressed: - break - turn += 1 - - # 应用删除 - total_deleted_confirm = 0 - for d in evaluated_dialogs: - dd = d["dialog"] - msgs = dd.context.msgs - original = len(msgs) - - if d["is_related"]: - result.append(dd) - continue - - idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) - if idx_in_unrel is None: - result.append(dd) - continue - - quota = alloc[idx_in_unrel] - info = per_dialog_info[d["index"]] - - # 计算删除ID - imp_count = len(info["imp_ids_sorted"]) - imp_del_cap = int(imp_count * proportion) - - unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) - del_unimp = min(quota, len(unimp_delete_ids)) - rem_quota = quota - del_unimp - - imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) - - deleted_here = 0 - actual_unimp_deleted = 0 - actual_imp_deleted = 0 - kept = [] - - for m in msgs: - mid = id(m) - if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: - actual_unimp_deleted += 1 - deleted_here += 1 - continue - if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids): - actual_imp_deleted += 1 - deleted_here += 1 - continue - kept.append(m) - - if not kept and msgs: - kept = [msgs[0]] - - dd.context.msgs = kept - total_deleted_confirm += deleted_here - - qa_info = f",问答对={info['qa_pairs_count']}" if info['qa_pairs_count'] > 0 else "" - self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}{qa_info}" - ) - result.append(dd) + # 确保至少保留1条消息 + max_deletable = max(0, original_count - 1) + delete_target = min(delete_target, max_deletable) - self._log(f"[剪枝-数据集] 全局消息级剪枝完成,总删除 {total_deleted_confirm} 条(保护问答对和上下文)。") - else: - result = [d["dialog"] for d in evaluated_dialogs] + # 删除策略:优先删除填充消息,再删除不重要消息 + to_delete_indices = set() + deleted_details = [] # 记录删除的消息详情 + + # 第一步:删除填充消息 + filler_to_delete = min(len(filler_msgs), delete_target) + for i in range(filler_to_delete): + idx, msg = filler_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'") + + # 第二步:如果还需要删除,删除不重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0: + unimp_to_delete = min(len(unimportant_msgs), remaining_quota) + for i in range(unimp_to_delete): + idx, msg = unimportant_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'") + + # 第三步:如果还需要删除,按重要性分数删除重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0 and important_msgs: + # 按重要性分数排序(分数低的优先删除) + imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1])) + imp_to_delete = min(len(imp_sorted), remaining_quota) + for i in range(imp_to_delete): + idx, msg = imp_sorted[i] + to_delete_indices.add(idx) + score = self._importance_score(msg) + deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'") + + # 执行删除 + kept_msgs = [] + for idx, m in enumerate(msgs): + if idx not in to_delete_indices: + kept_msgs.append(m) + + # 确保至少保留1条 + if not kept_msgs and msgs: + kept_msgs = [msgs[0]] + + dd.context.msgs = kept_msgs + deleted_count = original_count - len(kept_msgs) + total_deleted_msgs += deleted_count + + # 输出删除详情 + if deleted_details: + self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:") + for detail in deleted_details: + self._log(f" {detail}") + + # ========== 问答对统计(已注释) ========== + # qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else "" + # ======================================== + + self._log( + f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " + f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) " + f"删除={deleted_count} 保留={len(kept_msgs)}" + ) + + result.append(dd) self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py new file mode 100644 index 00000000..ed9592af --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py @@ -0,0 +1,326 @@ +""" +场景特定配置 - 为不同场景提供定制化的剪枝规则 + +功能: +- 场景特定的重要信息识别模式 +- 场景特定的重要性评分权重 +- 场景特定的填充词库 +- 场景特定的问答对识别规则 +""" + +from typing import Dict, List, Set, Tuple +from dataclasses import dataclass, field + + +@dataclass +class ScenePatterns: + """场景特定的识别模式""" + + # 重要信息的正则模式(优先级从高到低) + high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight) + medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + + # 填充词库(无意义对话) + filler_phrases: Set[str] = field(default_factory=set) + + # 问句关键词(用于识别问答对) + question_keywords: Set[str] = field(default_factory=set) + + # 决策性/承诺性关键词 + decision_keywords: Set[str] = field(default_factory=set) + + +class SceneConfigRegistry: + """场景配置注册表 - 管理所有场景的特定配置""" + + # 基础通用模式(所有场景共享) + BASE_HIGH_PRIORITY = [ + (r"订单号|工单|申请号|编号|ID|账号|账户", 5), + (r"金额|费用|价格|¥|¥|\d+元", 5), + (r"\d{11}", 4), # 手机号 + (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 + ] + + BASE_MEDIUM_PRIORITY = [ + (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期 + (r"\d{4}年\d{1,2}月\d{1,2}日", 3), + (r"电话|手机号|微信|QQ|联系方式", 3), + (r"地址|地点|位置", 2), + (r"时间|日期|有效期|截止", 2), + (r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重) + (r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3), + (r"今年|去年|明年", 3), + ] + + BASE_LOW_PRIORITY = [ + (r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM + (r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点 + (r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充) + (r"AM|PM|am|pm", 1), + ] + + BASE_FILLERS = { + # 基础寒暄 + "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", + "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", + "拜拜", "再见", "88", "拜", "回见", + # 口头禅 + "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", + "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", + # 确认词 + "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", + # 标点和符号 + "。。。", "...", "???", "???", "!!!", "!!!", + # 表情符号 + "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", + "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", + "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", + "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", + # 网络用语 + "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", + "emmm", "emm", "em", "mmp", "wtf", "omg", + } + + BASE_QUESTION_KEYWORDS = { + "什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗" + } + + BASE_DECISION_KEYWORDS = { + "必须", "一定", "务必", "需要", "要求", "规定", "应该", + "承诺", "保证", "确保", "负责", "同意", "答应" + } + + @classmethod + def get_education_config(cls) -> ScenePatterns: + """教育场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 成绩相关(最高优先级) + (r"成绩|分数|得分|满分|及格|不及格", 6), + (r"GPA|绩点|学分|平均分", 6), + (r"\d+分|\d+\.?\d*分", 5), # 具体分数 + (r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等 + + # 学籍信息 + (r"学号|学生证|教师工号|工号", 5), + (r"班级|年级|专业|院系", 4), + + # 课程相关 + (r"课程|科目|学科|必修|选修", 4), + (r"教材|课本|教科书|参考书", 4), + (r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等 + + # 学科内容(新增) + (r"微积分|导数|积分|函数|极限|微分", 4), + (r"代数|几何|三角|概率|统计", 4), + (r"物理|化学|生物|历史|地理", 4), + (r"英语|语文|数学|政治|哲学", 4), + (r"定义|定理|公式|概念|原理|法则", 3), + (r"例题|解题|证明|推导|计算", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 教学活动 + (r"作业|练习|习题|题目", 3), + (r"考试|测验|测试|考核|期中|期末", 3), + (r"上课|下课|课堂|讲课", 2), + (r"提问|回答|发言|讨论", 2), + (r"问一下|请教|咨询|询问", 2), # 新增:问询相关 + (r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态 + + # 时间安排 + (r"课表|课程表|时间表", 3), + (r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等 + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"老师|教师|同学|学生", 1), + (r"教室|实验室|图书馆", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义) + "老师好", "同学们好", "上课", "下课", "起立", "坐下", + "举手", "请坐", "很好", "不错", "继续", + "下一个", "下一题", "下一位", "还有吗", "还有问题吗", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "为啥", "咋", "咋办", "怎样", "如何做", + "能不能", "可不可以", "行不行", "对不对", "是不是", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "必考", "重点", "考点", "难点", "关键", + "记住", "背诵", "掌握", "理解", "复习", + } + ) + + @classmethod + def get_online_service_config(cls) -> ScenePatterns: + """在线服务场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 工单相关(最高优先级) + (r"工单号|工单编号|ticket|TK\d+", 6), + (r"工单状态|处理中|已解决|已关闭|待处理", 5), + (r"优先级|紧急|高优先级|P0|P1|P2", 5), + + # 产品信息 + (r"产品型号|型号|SKU|产品编号", 5), + (r"序列号|SN|设备号", 5), + (r"版本号|软件版本|固件版本", 4), + + # 问题描述 + (r"故障|错误|异常|bug|问题", 4), + (r"错误代码|故障代码|error code", 5), + (r"无法|不能|失败|报错", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 服务相关 + (r"退款|退货|换货|补发", 4), + (r"发票|收据|凭证", 3), + (r"物流|快递|运单号", 3), + (r"保修|质保|售后", 3), + + # 时效相关 + (r"SLA|响应时间|处理时长", 4), + (r"超时|延迟|等待", 2), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"客服|工程师|技术支持", 1), + (r"用户|客户|会员", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 在线服务特有填充词 + "您好", "请问", "请稍等", "稍等", "马上", "立即", + "正在查询", "正在处理", "正在为您", "帮您查一下", + "还有其他问题吗", "还需要什么帮助", "很高兴为您服务", + "感谢您的耐心等待", "抱歉让您久等了", + "已记录", "已反馈", "已转接", "已升级", + "祝您生活愉快", "再见", "欢迎下次咨询", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "能否", "可否", "是否", "有没有", "能不能", + "怎么办", "如何处理", "怎么解决", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "立即处理", "马上解决", "尽快", "优先", + "升级", "转接", "派单", "跟进", + "补偿", "赔偿", "退款", "换货", + } + ) + + @classmethod + def get_outbound_config(cls) -> ScenePatterns: + """外呼场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 意向相关(最高优先级) + (r"意向|意愿|兴趣|感兴趣", 6), + (r"A类|B类|C类|D类|高意向|低意向", 6), + (r"成交|签约|下单|购买|确认", 6), + + # 联系信息(外呼场景中更重要) + (r"预约|约定|安排|确定时间", 5), + (r"下次联系|回访|跟进", 5), + (r"方便|有空|可以|时间", 4), + + # 通话状态 + (r"接通|未接通|占线|关机|停机", 4), + (r"通话时长|通话时间", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 客户信息 + (r"姓名|称呼|先生|女士", 3), + (r"公司|单位|职位|职务", 3), + (r"需求|要求|期望", 3), + + # 跟进状态 + (r"跟进状态|进展|进度", 3), + (r"已联系|待联系|联系中", 2), + (r"拒绝|不感兴趣|考虑|再说", 3), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"销售|客户经理|业务员", 1), + (r"产品|服务|方案", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 外呼场景特有填充词 + "您好", "喂", "hello", "打扰了", "不好意思", + "方便接电话吗", "现在方便吗", "占用您一点时间", + "我是", "我们是", "我们公司", "我们这边", + "了解一下", "介绍一下", "简单说一下", + "考虑考虑", "想一想", "再说", "再看看", + "不需要", "不感兴趣", "没兴趣", "不用了", + "好的", "行", "可以", "没问题", "那就这样", + "再联系", "回头聊", "有需要再说", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "有没有", "需不需要", "要不要", "考虑不考虑", + "了解吗", "知道吗", "听说过吗", + "方便吗", "有空吗", "在吗", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "确定", "决定", "选择", "购买", "下单", + "预约", "安排", "约定", "确认", + "跟进", "回访", "联系", "沟通", + } + ) + + @classmethod + def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns: + """根据场景名称获取配置 + + Args: + scene: 场景名称 ('education', 'online_service', 'outbound' 或其他) + fallback_to_generic: 如果场景不存在,是否降级到通用配置 + + Returns: + 对应场景的配置,如果场景不存在: + - fallback_to_generic=True: 返回通用配置(仅基础规则) + - fallback_to_generic=False: 抛出异常 + """ + scene_map = { + 'education': cls.get_education_config, + 'online_service': cls.get_online_service_config, + 'outbound': cls.get_outbound_config, + } + + if scene in scene_map: + return scene_map[scene]() + + if fallback_to_generic: + # 返回通用配置(仅包含基础规则,不包含场景特定规则) + return cls.get_generic_config() + else: + raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}") + + @classmethod + def get_generic_config(cls) -> ScenePatterns: + """通用场景配置 - 仅包含基础规则,适用于未定义的场景 + + 这是一个保守的配置,只使用最通用的规则,避免误删重要信息 + """ + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY, + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY, + low_priority_patterns=cls.BASE_LOW_PRIORITY, + filler_phrases=cls.BASE_FILLERS, + question_keywords=cls.BASE_QUESTION_KEYWORDS, + decision_keywords=cls.BASE_DECISION_KEYWORDS + ) + + @classmethod + def get_all_scenes(cls) -> List[str]: + """获取所有预定义场景的列表""" + return ['education', 'online_service', 'outbound'] + + @classmethod + def is_scene_supported(cls, scene: str) -> bool: + """检查场景是否有专门的配置支持 + + Args: + scene: 场景名称 + + Returns: + True: 有专门配置 + False: 将使用通用配置 + """ + return scene in cls.get_all_scenes() diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index a47497da..17bda0e4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: Optional[str] = None, llm_client: Optional[Any] = None, skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2000,6 +2001,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: 输入数据路径 llm_client: LLM 客户端 skip_cleaning: 是否跳过数据清洗步骤(默认False) + pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold Returns: 带 chunks 的 DialogData 列表 @@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing( from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) - pruner = SemanticPruner(llm_client=llm_client) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + if pruning_config: + # 使用传入的配置 + config = PruningConfig(**pruning_config) + print(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + else: + # 使用默认配置(关闭剪枝) + config = None + print("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") + + pruner = SemanticPruner(config=config, llm_client=llm_client) # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None From 2d5c2de613ad9a5f4c3f3af5e3c05102917be4f5 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 17:51:12 +0800 Subject: [PATCH 030/164] [add]New semantic pruning effect display for streaming output --- api/app/services/pilot_run_service.py | 97 +++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 34b8867e..31e4d6dd 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -101,14 +101,101 @@ async def run_pilot_extraction( ) if progress_callback: - await progress_callback("text_preprocessing", "开始预处理文本...") + await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...") + # ========== 步骤 2.1: 语义剪枝 ========== + pruned_dialogs = [dialog] + deleted_messages = [] # 记录被删除的消息 + + if memory_config.pruning_enabled: + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( + SemanticPruner, + ) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + pruning_config_dict = { + "pruning_switch": memory_config.pruning_enabled, + "pruning_scene": memory_config.pruning_scene, + "pruning_threshold": memory_config.pruning_threshold, + "llm_model_id": str(memory_config.llm_model_id), + } + config = PruningConfig(**pruning_config_dict) + + logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}") + + # 记录剪枝前的消息(用于对比) + original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs] + original_msg_count = len(original_messages) + + # 执行剪枝 + pruner = SemanticPruner(config=config, llm_client=llm_client) + pruned_dialogs = await pruner.prune_dataset([dialog]) + + # 计算剪枝结果并找出被删除的消息 + if pruned_dialogs and pruned_dialogs[0].context: + remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs] + remaining_msg_count = len(remaining_messages) + deleted_msg_count = original_msg_count - remaining_msg_count + + # 找出被删除的消息(通过内容对比) + remaining_contents = {msg["content"] for msg in remaining_messages} + deleted_messages = [ + {"index": idx, "role": msg["role"], "content": msg["content"]} + for idx, msg in enumerate(original_messages) + if msg["content"] not in remaining_contents + ] + + pruning_result = { + "enabled": True, + "scene": config.pruning_scene, + "threshold": config.pruning_threshold, + "original_count": original_msg_count, + "remaining_count": remaining_msg_count, + "deleted_count": deleted_msg_count, + "deleted_messages": deleted_messages, + } + + logger.info( + f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> " + f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)" + ) + + if progress_callback: + await progress_callback("text_preprocessing_pruning", "语义剪枝完成", pruning_result) + else: + logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") + pruned_dialogs = [dialog] + + except Exception as e: + logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True) + pruned_dialogs = [dialog] + if progress_callback: + error_result = { + "enabled": True, + "error": str(e), + "fallback": "使用原始对话" + } + await progress_callback("text_preprocessing_pruning", "语义剪枝失败", error_result) + else: + logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") + if progress_callback: + pruning_result = { + "enabled": False, + "message": "语义剪枝已关闭" + } + await progress_callback("text_preprocessing_pruning", "语义剪枝已关闭", pruning_result) + + # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( - data=[dialog], + data=pruned_dialogs, chunker_strategy=memory_config.chunker_strategy, llm_client=llm_client, ) - logger.info(f"Processed dialogue text: {len(messages)} messages") + + remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0 + logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning") # 进度回调:输出每个分块的结果 if progress_callback: @@ -121,14 +208,14 @@ async def run_pilot_extraction( "dialog_id": dlg.id, "chunker_strategy": memory_config.chunker_strategy, } - await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + await progress_callback("text_preprocessing_chunking", f"分块 {i + 1} 处理完成", chunk_result) preprocessing_summary = { "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } - await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) + await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) From 4aeb653ed2633acba37a9c58890f7f254aca637e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 18:19:44 +0800 Subject: [PATCH 031/164] [fix]Fix the display issue of semantic chunking for streaming output --- api/app/services/pilot_run_service.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 31e4d6dd..4cfa158d 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -200,18 +200,19 @@ async def run_pilot_extraction( # 进度回调:输出每个分块的结果 if progress_callback: for dlg in chunked_dialogs: - for i, chunk in enumerate(dlg.chunks): - chunk_result = { - "chunk_index": i + 1, - "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, - "full_length": len(chunk.content), - "dialog_id": dlg.id, - "chunker_strategy": memory_config.chunker_strategy, - } - await progress_callback("text_preprocessing_chunking", f"分块 {i + 1} 处理完成", chunk_result) + if hasattr(dlg, 'chunks') and dlg.chunks: + for i, chunk in enumerate(dlg.chunks): + chunk_result = { + "chunk_index": i + 1, + "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, + "full_length": len(chunk.content), + "dialog_id": dlg.id, + "chunker_strategy": memory_config.chunker_strategy, + } + await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) preprocessing_summary = { - "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs), + "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } From 4ac63e1c239374abaaeb9325685af8f9ef0a63c3 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Sat, 28 Feb 2026 19:26:16 +0800 Subject: [PATCH 032/164] [add]Complete the interface integration for the display of semantic pruning for streaming output. --- api/app/services/pilot_run_service.py | 33 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 4cfa158d..c39d089e 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -106,6 +106,7 @@ async def run_pilot_extraction( # ========== 步骤 2.1: 语义剪枝 ========== pruned_dialogs = [dialog] deleted_messages = [] # 记录被删除的消息 + pruning_stats = None # 保存剪枝统计信息,用于最终汇总 if memory_config.pruning_enabled: try: @@ -147,13 +148,17 @@ async def run_pilot_extraction( if msg["content"] not in remaining_contents ] - pruning_result = { + # 保存剪枝统计信息(用于最终汇总,只保留deleted_count) + pruning_stats = { "enabled": True, "scene": config.pruning_scene, "threshold": config.pruning_threshold, - "original_count": original_msg_count, - "remaining_count": remaining_msg_count, "deleted_count": deleted_msg_count, + } + + # 输出剪枝结果(显示删除的消息详情) + pruning_result = { + "type": "pruning", "deleted_messages": deleted_messages, } @@ -163,7 +168,7 @@ async def run_pilot_extraction( ) if progress_callback: - await progress_callback("text_preprocessing_pruning", "语义剪枝完成", pruning_result) + await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result) else: logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") pruned_dialogs = [dialog] @@ -173,19 +178,16 @@ async def run_pilot_extraction( pruned_dialogs = [dialog] if progress_callback: error_result = { - "enabled": True, + "type": "pruning", "error": str(e), "fallback": "使用原始对话" } - await progress_callback("text_preprocessing_pruning", "语义剪枝失败", error_result) + await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result) else: logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") - if progress_callback: - pruning_result = { - "enabled": False, - "message": "语义剪枝已关闭" - } - await progress_callback("text_preprocessing_pruning", "语义剪枝已关闭", pruning_result) + pruning_stats = { + "enabled": False, + } # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( @@ -203,6 +205,7 @@ async def run_pilot_extraction( if hasattr(dlg, 'chunks') and dlg.chunks: for i, chunk in enumerate(dlg.chunks): chunk_result = { + "type": "chunking", "chunk_index": i + 1, "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, "full_length": len(chunk.content), @@ -211,11 +214,17 @@ async def run_pilot_extraction( } await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + # 构建预处理完成总结(包含剪枝统计) preprocessing_summary = { "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } + + # 添加剪枝统计信息 + if pruning_stats: + preprocessing_summary["pruning"] = pruning_stats + await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) From 8e15a340f685c9e448675694d80d3f56a3c40322 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 2 Mar 2026 12:09:10 +0800 Subject: [PATCH 033/164] [changes]Correct log output, log level, and pruning conditions --- .../data_preprocessing/data_pruning.py | 18 ++++++++++++--- .../extraction_orchestrator.py | 22 +++++++++---------- api/app/services/pilot_run_service.py | 16 +++++++++++--- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index d932c542..0a913633 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -82,6 +82,10 @@ class SemanticPruner: self.language = language # 保存语言配置 self.max_concurrent = max_concurrent # 新增:最大并发数 + # 详细日志配置:限制逐条消息日志的数量 + self._detailed_prune_logging = True # 是否启用详细日志 + self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 + # 加载场景特定配置 self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( self.config.pruning_scene, @@ -595,6 +599,11 @@ class SemanticPruner: unimportant_msgs = [] # 不重要消息(可删除) filler_msgs = [] # 填充消息(优先删除) + # 判断是否需要详细日志(仅对前N条消息记录) + should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog + if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog: + self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志") + for idx, m in enumerate(msgs): msg_text = m.msg.strip() @@ -607,15 +616,18 @@ class SemanticPruner: # 填充消息(寒暄、表情等) if self._is_filler_message(m): filler_msgs.append((idx, m)) - self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") # 重要信息(学号、成绩、时间、金额等) elif self._is_important_message(m): important_msgs.append((idx, m)) - self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)") + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)") # 其他消息 else: unimportant_msgs.append((idx, m)) - self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要") + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要") # 计算删除配额 delete_target = int(original_count * proportion) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 17bda0e4..1242e4e6 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1932,17 +1932,17 @@ def preprocess_data( Returns: 经过清洗转换后的 DialogData 列表 """ - print("\n=== 数据预处理 ===") + logger.debug("=== 数据预处理 ===") from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( DataPreprocessor, ) preprocessor = DataPreprocessor() try: cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) - print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") + logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: - print(f"数据预处理过程中出现错误: {e}") + logger.error(f"数据预处理过程中出现错误: {e}") raise @@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed( Returns: 带 chunks 的 DialogData 列表 """ - print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===") + logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") @@ -2006,7 +2006,7 @@ async def get_chunked_dialogs_with_preprocessing( Returns: 带 chunks 的 DialogData 列表 """ - print("\n=== 完整数据处理流程(包含预处理)===") + logger.debug("=== 完整数据处理流程(包含预处理)===") if input_data_path is None: input_data_path = os.path.join( @@ -2038,11 +2038,11 @@ async def get_chunked_dialogs_with_preprocessing( if pruning_config: # 使用传入的配置 config = PruningConfig(**pruning_config) - print(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") else: # 使用默认配置(关闭剪枝) config = None - print("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") + logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") pruner = SemanticPruner(config=config, llm_client=llm_client) @@ -2057,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing( if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs) - print( + logger.debug( f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}," f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。" ) else: - print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") + logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") # 保存剪枝后的数据 try: @@ -2073,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing( dp = DataPreprocessor(output_file_path=pruned_output_path) dp.save_data(preprocessed_data, output_path=pruned_output_path) except Exception as se: - print(f"保存剪枝结果失败:{se}") + logger.error(f"保存剪枝结果失败:{se}") except Exception as e: - print(f"语义剪枝过程中出现错误,跳过剪枝: {e}") + logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index c39d089e..4d9cbb5e 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -140,12 +140,22 @@ async def run_pilot_extraction( remaining_msg_count = len(remaining_messages) deleted_msg_count = original_msg_count - remaining_msg_count - # 找出被删除的消息(通过内容对比) - remaining_contents = {msg["content"] for msg in remaining_messages} + # 找出被删除的消息(基于索引精确匹配) + # 为剩余消息创建带索引的列表,用于精确追踪 + remaining_with_index = [] + remaining_idx = 0 + for orig_idx, orig_msg in enumerate(original_messages): + if remaining_idx < len(remaining_messages) and \ + orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \ + orig_msg["content"] == remaining_messages[remaining_idx]["content"]: + remaining_with_index.append(orig_idx) + remaining_idx += 1 + + # 找出未在保留列表中的消息索引 deleted_messages = [ {"index": idx, "role": msg["role"], "content": msg["content"]} for idx, msg in enumerate(original_messages) - if msg["content"] not in remaining_contents + if idx not in remaining_with_index ] # 保存剪枝统计信息(用于最终汇总,只保留deleted_count) From 2ff9000d2527094d0b36e6d939b90a5447cdbfb9 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Fri, 27 Feb 2026 13:45:31 +0800 Subject: [PATCH 034/164] feat(web): form add rules --- web/src/components/SearchInput/index.tsx | 2 + web/src/i18n/en.ts | 5 +- web/src/i18n/zh.ts | 5 +- web/src/utils/validator.ts | 52 +++++++++++++++++++ .../components/ApiKeyModal.tsx | 10 +++- web/src/views/ApplicationConfig/Agent.tsx | 10 ++-- .../components/ApplicationModal.tsx | 8 ++- .../components/MemberModal.tsx | 6 ++- .../components/MemoryForm.tsx | 8 ++- .../components/CustomModelModal.tsx | 20 +++++-- .../components/GroupModelModal.tsx | 23 ++++++-- web/src/views/ModelManagement/index.tsx | 1 + .../components/OntologyClassExtractModal.tsx | 5 +- .../components/OntologyClassModal.tsx | 8 ++- .../Ontology/components/OntologyModal.tsx | 8 ++- web/src/views/Skills/pages/SkillConfig.tsx | 10 +++- web/src/views/Skills/types.ts | 2 + .../SpaceManagement/components/SpaceModal.tsx | 12 +++-- 18 files changed, 167 insertions(+), 28 deletions(-) create mode 100644 web/src/utils/validator.ts diff --git a/web/src/components/SearchInput/index.tsx b/web/src/components/SearchInput/index.tsx index 32a64310..476c2cbb 100644 --- a/web/src/components/SearchInput/index.tsx +++ b/web/src/components/SearchInput/index.tsx @@ -41,6 +41,8 @@ interface SearchInputProps { className?: string; /** Input size */ size?: InputProps['size'] + /** Maximum length of the input value */ + maxLength?: number; } /** Search input component with debounce and throttle support */ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 02add0ec..c5c81f13 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -453,6 +453,8 @@ export const en = { prevStep: 'Previous Step', exportSuccess: 'Export successful', recommend: 'Recommend', + logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, + imageSquareRequired: 'Please upload a square image', }, model: { searchPlaceholder: 'search model…', @@ -542,7 +544,8 @@ export const en = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + nameInvalid: 'Model name can only contain letters, numbers, underscores and spaces, cannot be empty or pure whitespace', }, modelNew: { group: 'Model Group', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 06abf63a..f63e5981 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1029,6 +1029,8 @@ export const zh = { prevStep: '上一步', exportSuccess: '导出成功', recommend: '推荐', + logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, + imageSquareRequired: '请上传正方形比例图片', }, model: { searchPlaceholder: '搜索模型…', @@ -1176,7 +1178,8 @@ export const zh = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + nameInvalid: '模型名称只能包含字母、数字、下划线和空格, 不能为空或纯空格', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', diff --git a/web/src/utils/validator.ts b/web/src/utils/validator.ts new file mode 100644 index 00000000..70f0cb02 --- /dev/null +++ b/web/src/utils/validator.ts @@ -0,0 +1,52 @@ +/* + * @Author: zhaoying yzhao96@best-inc.com + * @Date: 2026-03-02 13:46:53 + * @LastEditors: zhaoying yzhao96@best-inc.com + * @LastEditTime: 2026-03-02 14:38:33 + * @FilePath: /web/src/utils/validator.ts + * @Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE + */ +/** + * Form validation utilities + */ + +interface UploadFile { + originFileObj: Blob; + [key: string]: unknown; +} + +/** + * Validate if uploaded image is square (width === height) + * @param errorMessage - Error message to display when validation fails + * @returns Ant Design form validator + */ +export const validateSquareImage = (errorMessage: string = 'Image must be square') => { + return (_: unknown, value: UploadFile | UploadFile[] | undefined) => { + if (!value || (Array.isArray(value) && value.length === 0)) { + return Promise.resolve(); + } + + const file = Array.isArray(value) ? value[0] : value; + + if (file?.originFileObj) { + return new Promise((resolve, reject) => { + const img = new Image(); + img.onload = () => { + if (img.width === img.height) { + resolve(); + } else { + reject(new Error(errorMessage)); + } + }; + img.onerror = () => reject(new Error('Failed to load image')); + img.src = URL.createObjectURL(file.originFileObj); + }); + } + + return Promise.resolve(); + }; +}; + +// - Cannot be empty or pure whitespace +// - Cannot start with a space +export const stringRegExp = /^[a-zA-Z0-9\u4e00-\u9fa5][a-zA-Z0-9\u4e00-\u9fa5\s]*$/ \ No newline at end of file diff --git a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx index 9395df43..05e73992 100644 --- a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx +++ b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx @@ -12,6 +12,7 @@ import dayjs from 'dayjs' import type { ApiKey, ApiKeyModalRef } from '../types'; import RbModal from '@/components/RbModal' import { createApiKey, updateApiKey } from '@/api/apiKey'; +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -78,7 +79,7 @@ const ApiKeyModal = forwardRef(({ form.validateFields() .then((values) => { const { memory, rag, expires_at, ...rest } = values - let scopes = [] + const scopes = [] if (memory) { scopes.push('memory') @@ -130,7 +131,11 @@ const ApiKeyModal = forwardRef(({ @@ -138,6 +143,7 @@ const ApiKeyModal = forwardRef(({ diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 2ece4b6e..4bee291b 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -169,8 +169,8 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config const { skills, variables } = response - let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] - let allTools = Array.isArray(response.tools) ? response.tools : [] + const allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] + const allTools = Array.isArray(response.tools) ? response.tools : [] const memoryContent = response.memory?.memory_config_id const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined @@ -431,7 +431,11 @@ const Agent = forwardRef((_props, ref) => {
- + ( diff --git a/web/src/views/MemberManagement/components/MemberModal.tsx b/web/src/views/MemberManagement/components/MemberModal.tsx index 002c8632..e16c60ba 100644 --- a/web/src/views/MemberManagement/components/MemberModal.tsx +++ b/web/src/views/MemberManagement/components/MemberModal.tsx @@ -152,7 +152,11 @@ const MemberModal = forwardRef(({ diff --git a/web/src/views/MemoryManagement/components/MemoryForm.tsx b/web/src/views/MemoryManagement/components/MemoryForm.tsx index 93246ca9..282199af 100644 --- a/web/src/views/MemoryManagement/components/MemoryForm.tsx +++ b/web/src/views/MemoryManagement/components/MemoryForm.tsx @@ -18,6 +18,7 @@ import RbModal from '@/components/RbModal' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' import { getOntologyScenesSimpleUrl } from '@/api/ontology' import CustomSelect from '@/components/CustomSelect'; +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -110,7 +111,11 @@ const MemoryForm = forwardRef(({ @@ -118,6 +123,7 @@ const MemoryForm = forwardRef(({ diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx index 17373a02..112534a5 100644 --- a/web/src/views/ModelManagement/components/CustomModelModal.tsx +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -20,6 +20,7 @@ import CustomSelect from '@/components/CustomSelect' import UploadImages from '@/components/Upload/UploadImages' import { updateCustomModel, addCustomModel, modelTypeUrl, modelProviderUrl } from '@/api/models' import { getFileLink } from '@/api/fileStorage' +import { validateSquareImage, stringRegExp } from '@/utils/validator' /** * Custom model modal component @@ -65,7 +66,7 @@ const CustomModelModal = forwardRef( const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data) res.then(() => { - refresh && refresh(isEdit) + refresh?.(isEdit) handleClose() message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess')) }) @@ -79,7 +80,7 @@ const CustomModelModal = forwardRef( .validateFields() .then((values) => { const { logo, ...rest } = values; - let formData: CustomModelForm = { + const formData: CustomModelForm = { ...rest } @@ -125,14 +126,22 @@ const CustomModelModal = forwardRef( name="logo" label={t('modelNew.logo')} valuePropName="fileList" - rules={[{ required: true, message: t('common.pleaseSelect') }]} + rules={[ + { required: true, message: t('common.pleaseSelect') }, + { validator: validateSquareImage(t('common.imageSquareRequired')) } + ]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
@@ -166,6 +175,7 @@ const CustomModelModal = forwardRef( diff --git a/web/src/views/ModelManagement/components/GroupModelModal.tsx b/web/src/views/ModelManagement/components/GroupModelModal.tsx index efcd77f6..5ca46548 100644 --- a/web/src/views/ModelManagement/components/GroupModelModal.tsx +++ b/web/src/views/ModelManagement/components/GroupModelModal.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:49:33 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:49:33 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-02 12:23:13 */ /** * Group Model Modal @@ -21,6 +21,7 @@ import { updateCompositeModel, modelTypeUrl, addCompositeModel } from '@/api/mod import UploadImages from '@/components/Upload/UploadImages' import ModelImplement from './ModelImplement' import { getFileLink } from '@/api/fileStorage' +import { validateSquareImage, stringRegExp } from '@/utils/validator' /** * Group model modal component @@ -133,15 +134,26 @@ const GroupModelModal = forwardRef(({ name="logo" label={t('modelNew.logo')} valuePropName="fileList" - rules={[{ required: true, message: t('common.pleaseSelect') }]} + rules={[ + { required: true, message: t('common.pleaseSelect') }, + { validator: validateSquareImage(t('common.imageSquareRequired')) } + ]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
@@ -165,6 +177,7 @@ const GroupModelModal = forwardRef(({ diff --git a/web/src/views/ModelManagement/index.tsx b/web/src/views/ModelManagement/index.tsx index 35d7d864..539ff5e3 100644 --- a/web/src/views/ModelManagement/index.tsx +++ b/web/src/views/ModelManagement/index.tsx @@ -121,6 +121,7 @@ const tabKeys = ['group', 'list', 'square'] {activeTab !== 'list' && diff --git a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx index 802202ef..2fd305c6 100644 --- a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx +++ b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx @@ -182,7 +182,10 @@ const OntologyClassExtractModal = forwardRef diff --git a/web/src/views/Ontology/components/OntologyClassModal.tsx b/web/src/views/Ontology/components/OntologyClassModal.tsx index 087e542c..a467294e 100644 --- a/web/src/views/Ontology/components/OntologyClassModal.tsx +++ b/web/src/views/Ontology/components/OntologyClassModal.tsx @@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next'; import type { AddClassItem, OntologyClassModalRef } from '../types' import RbModal from '@/components/RbModal' import { createOntologyClass } from '@/api/ontology' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -105,7 +106,11 @@ const OntologyClassModal = forwardRef @@ -113,6 +118,7 @@ const OntologyClassModal = forwardRef diff --git a/web/src/views/Ontology/components/OntologyModal.tsx b/web/src/views/Ontology/components/OntologyModal.tsx index a4c203ed..92e94bb6 100644 --- a/web/src/views/Ontology/components/OntologyModal.tsx +++ b/web/src/views/Ontology/components/OntologyModal.tsx @@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next'; import type { OntologyItem, OntologyModalData, OntologyModalRef } from '../types' import RbModal from '@/components/RbModal' import { createOntologyScene, updateOntologyScene } from '@/api/ontology' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -109,7 +110,11 @@ const OntologyModal = forwardRef(({ @@ -117,6 +122,7 @@ const OntologyModal = forwardRef(({ diff --git a/web/src/views/Skills/pages/SkillConfig.tsx b/web/src/views/Skills/pages/SkillConfig.tsx index 6e12e72d..f9f76dea 100644 --- a/web/src/views/Skills/pages/SkillConfig.tsx +++ b/web/src/views/Skills/pages/SkillConfig.tsx @@ -17,6 +17,7 @@ import type { AiPromptModalRef } from '@/views/ApplicationConfig/types' import exitIcon from '@/assets/images/knowledgeBase/exit.png'; import type { SkillFormData } from '../types' import { getSkillDetail, createSkill, updateSkill } from '@/api/skill' +import { stringRegExp } from '@/utils/validator'; /** * Skill Configuration Page Component @@ -110,7 +111,7 @@ const SkillConfig: FC = () => { // Format tools data for API const formData = { ...rest, - tools: tools?.map((item: any) => ({ + tools: tools?.map((item) => ({ tool_id: item.tool_id, operation: item.operation })) @@ -144,13 +145,18 @@ const SkillConfig: FC = () => { diff --git a/web/src/views/Skills/types.ts b/web/src/views/Skills/types.ts index 950bfb03..c0df3fbf 100644 --- a/web/src/views/Skills/types.ts +++ b/web/src/views/Skills/types.ts @@ -17,6 +17,8 @@ export interface SkillFormData { tools: Array<{ /** Tool identifier */ tool_id: string; + /** Tool operation/action */ + operation?: string; }>; /** Skill configuration settings */ config: { diff --git a/web/src/views/SpaceManagement/components/SpaceModal.tsx b/web/src/views/SpaceManagement/components/SpaceModal.tsx index 4f37b246..5a639244 100644 --- a/web/src/views/SpaceManagement/components/SpaceModal.tsx +++ b/web/src/views/SpaceManagement/components/SpaceModal.tsx @@ -23,6 +23,7 @@ import UploadImages from '@/components/Upload/UploadImages' import { getFileLink } from '@/api/fileStorage' import ragIcon from '@/assets/images/space/rag.png' import neo4jIcon from '@/assets/images/space/neo4j.png' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -91,7 +92,7 @@ const SpaceModal = forwardRef(({ setCurrentStep(1) } else { const { icon, ...rest } = values - let formData: SpaceModalData = { + const formData: SpaceModalData = { ...rest } if (icon?.response?.data.file_id) { @@ -164,14 +165,19 @@ const SpaceModal = forwardRef(({ valuePropName="fileList" hidden={currentStep === 1} rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
From 62b2ecdfc2eff71c372a03d5b1634cf16693f703 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 14:41:58 +0800 Subject: [PATCH 035/164] feat(web): form add rules --- web/src/components/SearchInput/index.tsx | 2 + web/src/i18n/en.ts | 5 +- web/src/i18n/zh.ts | 5 +- web/src/utils/validator.ts | 50 +++++++++++++++++++ .../components/ApiKeyModal.tsx | 10 +++- web/src/views/ApplicationConfig/Agent.tsx | 10 ++-- .../components/ApplicationModal.tsx | 8 ++- .../components/MemberModal.tsx | 6 ++- .../components/MemoryForm.tsx | 8 ++- .../components/CustomModelModal.tsx | 20 ++++++-- .../components/GroupModelModal.tsx | 23 +++++++-- web/src/views/ModelManagement/index.tsx | 1 + .../components/OntologyClassExtractModal.tsx | 5 +- .../components/OntologyClassModal.tsx | 8 ++- .../Ontology/components/OntologyModal.tsx | 8 ++- web/src/views/Skills/pages/SkillConfig.tsx | 10 +++- web/src/views/Skills/types.ts | 2 + .../SpaceManagement/components/SpaceModal.tsx | 12 +++-- 18 files changed, 165 insertions(+), 28 deletions(-) create mode 100644 web/src/utils/validator.ts diff --git a/web/src/components/SearchInput/index.tsx b/web/src/components/SearchInput/index.tsx index 32a64310..476c2cbb 100644 --- a/web/src/components/SearchInput/index.tsx +++ b/web/src/components/SearchInput/index.tsx @@ -41,6 +41,8 @@ interface SearchInputProps { className?: string; /** Input size */ size?: InputProps['size'] + /** Maximum length of the input value */ + maxLength?: number; } /** Search input component with debounce and throttle support */ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 02add0ec..c5c81f13 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -453,6 +453,8 @@ export const en = { prevStep: 'Previous Step', exportSuccess: 'Export successful', recommend: 'Recommend', + logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, + imageSquareRequired: 'Please upload a square image', }, model: { searchPlaceholder: 'search model…', @@ -542,7 +544,8 @@ export const en = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + nameInvalid: 'Model name can only contain letters, numbers, underscores and spaces, cannot be empty or pure whitespace', }, modelNew: { group: 'Model Group', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 06abf63a..f63e5981 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1029,6 +1029,8 @@ export const zh = { prevStep: '上一步', exportSuccess: '导出成功', recommend: '推荐', + logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, + imageSquareRequired: '请上传正方形比例图片', }, model: { searchPlaceholder: '搜索模型…', @@ -1176,7 +1178,8 @@ export const zh = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + nameInvalid: '模型名称只能包含字母、数字、下划线和空格, 不能为空或纯空格', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', diff --git a/web/src/utils/validator.ts b/web/src/utils/validator.ts new file mode 100644 index 00000000..c55c52ca --- /dev/null +++ b/web/src/utils/validator.ts @@ -0,0 +1,50 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-03-02 13:46:53 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-02 14:38:33 + */ +/** + * Form validation utilities + */ + +interface UploadFile { + originFileObj: Blob; + [key: string]: unknown; +} + +/** + * Validate if uploaded image is square (width === height) + * @param errorMessage - Error message to display when validation fails + * @returns Ant Design form validator + */ +export const validateSquareImage = (errorMessage: string = 'Image must be square') => { + return (_: unknown, value: UploadFile | UploadFile[] | undefined) => { + if (!value || (Array.isArray(value) && value.length === 0)) { + return Promise.resolve(); + } + + const file = Array.isArray(value) ? value[0] : value; + + if (file?.originFileObj) { + return new Promise((resolve, reject) => { + const img = new Image(); + img.onload = () => { + if (img.width === img.height) { + resolve(); + } else { + reject(new Error(errorMessage)); + } + }; + img.onerror = () => reject(new Error('Failed to load image')); + img.src = URL.createObjectURL(file.originFileObj); + }); + } + + return Promise.resolve(); + }; +}; + +// - Cannot be empty or pure whitespace +// - Cannot start with a space +export const stringRegExp = /^[a-zA-Z0-9\u4e00-\u9fa5][a-zA-Z0-9\u4e00-\u9fa5\s]*$/ \ No newline at end of file diff --git a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx index 9395df43..05e73992 100644 --- a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx +++ b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx @@ -12,6 +12,7 @@ import dayjs from 'dayjs' import type { ApiKey, ApiKeyModalRef } from '../types'; import RbModal from '@/components/RbModal' import { createApiKey, updateApiKey } from '@/api/apiKey'; +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -78,7 +79,7 @@ const ApiKeyModal = forwardRef(({ form.validateFields() .then((values) => { const { memory, rag, expires_at, ...rest } = values - let scopes = [] + const scopes = [] if (memory) { scopes.push('memory') @@ -130,7 +131,11 @@ const ApiKeyModal = forwardRef(({ @@ -138,6 +143,7 @@ const ApiKeyModal = forwardRef(({ diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 2ece4b6e..4bee291b 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -169,8 +169,8 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config const { skills, variables } = response - let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] - let allTools = Array.isArray(response.tools) ? response.tools : [] + const allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] + const allTools = Array.isArray(response.tools) ? response.tools : [] const memoryContent = response.memory?.memory_config_id const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined @@ -431,7 +431,11 @@ const Agent = forwardRef((_props, ref) => {
- + ( diff --git a/web/src/views/MemberManagement/components/MemberModal.tsx b/web/src/views/MemberManagement/components/MemberModal.tsx index 002c8632..e16c60ba 100644 --- a/web/src/views/MemberManagement/components/MemberModal.tsx +++ b/web/src/views/MemberManagement/components/MemberModal.tsx @@ -152,7 +152,11 @@ const MemberModal = forwardRef(({ diff --git a/web/src/views/MemoryManagement/components/MemoryForm.tsx b/web/src/views/MemoryManagement/components/MemoryForm.tsx index 93246ca9..282199af 100644 --- a/web/src/views/MemoryManagement/components/MemoryForm.tsx +++ b/web/src/views/MemoryManagement/components/MemoryForm.tsx @@ -18,6 +18,7 @@ import RbModal from '@/components/RbModal' import { createMemoryConfig, updateMemoryConfig } from '@/api/memory' import { getOntologyScenesSimpleUrl } from '@/api/ontology' import CustomSelect from '@/components/CustomSelect'; +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -110,7 +111,11 @@ const MemoryForm = forwardRef(({ @@ -118,6 +123,7 @@ const MemoryForm = forwardRef(({ diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx index 17373a02..112534a5 100644 --- a/web/src/views/ModelManagement/components/CustomModelModal.tsx +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -20,6 +20,7 @@ import CustomSelect from '@/components/CustomSelect' import UploadImages from '@/components/Upload/UploadImages' import { updateCustomModel, addCustomModel, modelTypeUrl, modelProviderUrl } from '@/api/models' import { getFileLink } from '@/api/fileStorage' +import { validateSquareImage, stringRegExp } from '@/utils/validator' /** * Custom model modal component @@ -65,7 +66,7 @@ const CustomModelModal = forwardRef( const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data) res.then(() => { - refresh && refresh(isEdit) + refresh?.(isEdit) handleClose() message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess')) }) @@ -79,7 +80,7 @@ const CustomModelModal = forwardRef( .validateFields() .then((values) => { const { logo, ...rest } = values; - let formData: CustomModelForm = { + const formData: CustomModelForm = { ...rest } @@ -125,14 +126,22 @@ const CustomModelModal = forwardRef( name="logo" label={t('modelNew.logo')} valuePropName="fileList" - rules={[{ required: true, message: t('common.pleaseSelect') }]} + rules={[ + { required: true, message: t('common.pleaseSelect') }, + { validator: validateSquareImage(t('common.imageSquareRequired')) } + ]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
@@ -166,6 +175,7 @@ const CustomModelModal = forwardRef( diff --git a/web/src/views/ModelManagement/components/GroupModelModal.tsx b/web/src/views/ModelManagement/components/GroupModelModal.tsx index efcd77f6..5ca46548 100644 --- a/web/src/views/ModelManagement/components/GroupModelModal.tsx +++ b/web/src/views/ModelManagement/components/GroupModelModal.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:49:33 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:49:33 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-02 12:23:13 */ /** * Group Model Modal @@ -21,6 +21,7 @@ import { updateCompositeModel, modelTypeUrl, addCompositeModel } from '@/api/mod import UploadImages from '@/components/Upload/UploadImages' import ModelImplement from './ModelImplement' import { getFileLink } from '@/api/fileStorage' +import { validateSquareImage, stringRegExp } from '@/utils/validator' /** * Group model modal component @@ -133,15 +134,26 @@ const GroupModelModal = forwardRef(({ name="logo" label={t('modelNew.logo')} valuePropName="fileList" - rules={[{ required: true, message: t('common.pleaseSelect') }]} + rules={[ + { required: true, message: t('common.pleaseSelect') }, + { validator: validateSquareImage(t('common.imageSquareRequired')) } + ]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
@@ -165,6 +177,7 @@ const GroupModelModal = forwardRef(({ diff --git a/web/src/views/ModelManagement/index.tsx b/web/src/views/ModelManagement/index.tsx index 35d7d864..539ff5e3 100644 --- a/web/src/views/ModelManagement/index.tsx +++ b/web/src/views/ModelManagement/index.tsx @@ -121,6 +121,7 @@ const tabKeys = ['group', 'list', 'square'] {activeTab !== 'list' && diff --git a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx index 802202ef..2fd305c6 100644 --- a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx +++ b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx @@ -182,7 +182,10 @@ const OntologyClassExtractModal = forwardRef diff --git a/web/src/views/Ontology/components/OntologyClassModal.tsx b/web/src/views/Ontology/components/OntologyClassModal.tsx index 087e542c..a467294e 100644 --- a/web/src/views/Ontology/components/OntologyClassModal.tsx +++ b/web/src/views/Ontology/components/OntologyClassModal.tsx @@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next'; import type { AddClassItem, OntologyClassModalRef } from '../types' import RbModal from '@/components/RbModal' import { createOntologyClass } from '@/api/ontology' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -105,7 +106,11 @@ const OntologyClassModal = forwardRef @@ -113,6 +118,7 @@ const OntologyClassModal = forwardRef diff --git a/web/src/views/Ontology/components/OntologyModal.tsx b/web/src/views/Ontology/components/OntologyModal.tsx index a4c203ed..92e94bb6 100644 --- a/web/src/views/Ontology/components/OntologyModal.tsx +++ b/web/src/views/Ontology/components/OntologyModal.tsx @@ -11,6 +11,7 @@ import { useTranslation } from 'react-i18next'; import type { OntologyItem, OntologyModalData, OntologyModalRef } from '../types' import RbModal from '@/components/RbModal' import { createOntologyScene, updateOntologyScene } from '@/api/ontology' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -109,7 +110,11 @@ const OntologyModal = forwardRef(({ @@ -117,6 +122,7 @@ const OntologyModal = forwardRef(({ diff --git a/web/src/views/Skills/pages/SkillConfig.tsx b/web/src/views/Skills/pages/SkillConfig.tsx index 6e12e72d..f9f76dea 100644 --- a/web/src/views/Skills/pages/SkillConfig.tsx +++ b/web/src/views/Skills/pages/SkillConfig.tsx @@ -17,6 +17,7 @@ import type { AiPromptModalRef } from '@/views/ApplicationConfig/types' import exitIcon from '@/assets/images/knowledgeBase/exit.png'; import type { SkillFormData } from '../types' import { getSkillDetail, createSkill, updateSkill } from '@/api/skill' +import { stringRegExp } from '@/utils/validator'; /** * Skill Configuration Page Component @@ -110,7 +111,7 @@ const SkillConfig: FC = () => { // Format tools data for API const formData = { ...rest, - tools: tools?.map((item: any) => ({ + tools: tools?.map((item) => ({ tool_id: item.tool_id, operation: item.operation })) @@ -144,13 +145,18 @@ const SkillConfig: FC = () => { diff --git a/web/src/views/Skills/types.ts b/web/src/views/Skills/types.ts index 950bfb03..c0df3fbf 100644 --- a/web/src/views/Skills/types.ts +++ b/web/src/views/Skills/types.ts @@ -17,6 +17,8 @@ export interface SkillFormData { tools: Array<{ /** Tool identifier */ tool_id: string; + /** Tool operation/action */ + operation?: string; }>; /** Skill configuration settings */ config: { diff --git a/web/src/views/SpaceManagement/components/SpaceModal.tsx b/web/src/views/SpaceManagement/components/SpaceModal.tsx index 4f37b246..5a639244 100644 --- a/web/src/views/SpaceManagement/components/SpaceModal.tsx +++ b/web/src/views/SpaceManagement/components/SpaceModal.tsx @@ -23,6 +23,7 @@ import UploadImages from '@/components/Upload/UploadImages' import { getFileLink } from '@/api/fileStorage' import ragIcon from '@/assets/images/space/rag.png' import neo4jIcon from '@/assets/images/space/neo4j.png' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -91,7 +92,7 @@ const SpaceModal = forwardRef(({ setCurrentStep(1) } else { const { icon, ...rest } = values - let formData: SpaceModalData = { + const formData: SpaceModalData = { ...rest } if (icon?.response?.data.file_id) { @@ -164,14 +165,19 @@ const SpaceModal = forwardRef(({ valuePropName="fileList" hidden={currentStep === 1} rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('space.spaceIcon') }) }]} + extra={t('common.logoTip')?.split('\n').map((vo, index) =>
{vo}
)} > - +
From 9be1c01b70a4ea741c00803894ae03184151ec7d Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 14:43:44 +0800 Subject: [PATCH 036/164] feat(web): chat content support scroll --- web/src/components/Chat/ChatContent.tsx | 35 ++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 32e6ae23..c1f5223c 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -27,12 +27,45 @@ const ChatContent: FC = ({ }) => { // Scroll container reference for controlling auto-scroll to bottom const scrollContainerRef = useRef<(HTMLDivElement | null)>(null) + const prevDataLengthRef = useRef(data.length); + const isScrolledToBottomRef = useRef(true); // Track if user is scrolled to bottom + + // Track scroll position to determine if user is at bottom + useEffect(() => { + const handleScroll = () => { + if (scrollContainerRef.current) { + const { scrollTop, scrollHeight, clientHeight } = scrollContainerRef.current; + // Consider user is at bottom if within 20px of the bottom + isScrolledToBottomRef.current = scrollHeight - scrollTop - clientHeight < 20; + } + }; + + const container = scrollContainerRef.current; + if (container) { + container.addEventListener('scroll', handleScroll); + // Initial check + handleScroll(); + } + + return () => { + if (container) { + container.removeEventListener('scroll', handleScroll); + } + }; + }, []); // Auto-scroll to bottom when data changes to show latest messages + // When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom + // When data array length changes, auto-scroll to bottom + // If already scrolled to bottom, will auto-scroll to bottom useEffect(() => { setTimeout(() => { if (scrollContainerRef.current) { - scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight; + // Auto-scroll if data length changed OR user is currently at bottom + if (data.length !== prevDataLengthRef.current || isScrolledToBottomRef.current) { + scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight; + } + prevDataLengthRef.current = data.length; } }, 0); }, [data]) From 5cf2b087771af886debc1f681c5ee89f5e21028b Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 2 Mar 2026 14:52:51 +0800 Subject: [PATCH 037/164] fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug --- .../core/workflow/adapters/dify/converter.py | 2 +- api/app/core/workflow/engine/variable_pool.py | 16 +++++++- .../core/workflow/nodes/cycle_graph/node.py | 2 +- .../core/workflow/nodes/http_request/node.py | 35 ++++++++++++---- .../workflow/variable/variable_objects.py | 41 ++++++++++++++----- api/app/services/workflow_service.py | 1 + 6 files changed, 76 insertions(+), 21 deletions(-) diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 18beef15..798b78b7 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -671,4 +671,4 @@ class DifyConverter(BaseConverter): type=ExceptionType.CONFIG, detail=f"Please reconfigure the tool node.", )) - return {} \ No newline at end of file + return {} diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 22be08c8..fd28eba5 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]): instance: The concrete variable object. The actual Python type is represented by the generic parameter ``T`` (e.g. StringVariable, - NumberVariable, ArrayObject[StringVariable]). + NumberVariable, ArrayVariable[StringVariable]). mut: Whether the variable is mutable. """ @@ -152,6 +152,20 @@ class VariablePool: return None return var_instance + def get_instance( + self, + selector: str, + default: Any = None, + strict: bool = True + ): + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default + + return variable_struct.instance + def get_value( self, selector: str, diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index f2912e2c..71e0dbdb 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode): if config.flatten: outputs['output'] = config.output_type else: - outputs['output'] = VariableType.ARRAY_STRING + outputs['output'] = VariableType.NESTED_ARRAY else: outputs['output'] = VariableType(f"array[{config.output_type}]") return outputs diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index df899940..e6c00eff 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import uuid from typing import Any, Callable, Coroutine import httpx @@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable logger = logging.getLogger(__file__) @@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode): params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return params - def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: + async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: """ Build HTTP request body arguments for httpx request methods. @@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode): )) case HttpContentType.FROM_DATA: data = {} + content["files"] = {} for item in self.typed_config.body.data: if item.type == "text": - data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) + data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, + variable_pool) elif item.type == "file": - # TODO: File support (Feature) - pass + content["files"][self._render_template(item.key, variable_pool)] = ( + uuid.uuid4().hex, + await variable_pool.get_instance(item.value).get_content() + ) content["data"] = data case HttpContentType.BINARY: - # TODO: File support (Feature) - pass + content["files"] = [] + file_instence = variable_pool.get_instance(self.typed_config.body.data) + if isinstance(file_instence, ArrayVariable): + for v in file_instence.value: + if isinstance(v, FileVariable): + content["files"].append( + ( + "files", (uuid.uuid4().hex, await v.get_content()) + ) + ) + elif isinstance(file_instence, FileVariable): + content["files"].append( + ( + "file", (uuid.uuid4().hex, await file_instence.get_content()) + ) + ) + case HttpContentType.WWW_FORM: content["data"] = json.loads(self._render_template( json.dumps(self.typed_config.body.data), variable_pool @@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode): request_func = self._get_client_method(client) resp = await request_func( url=self._render_template(self.typed_config.url, variable_pool), - **self._build_content(variable_pool) + **(await self._build_content(variable_pool)) ) resp.raise_for_status() logger.info(f"Node {self.node_id}: HTTP request succeeded") diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 7a39835c..49541afc 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -1,8 +1,10 @@ from typing import Any, TypeVar, Type, Generic +import httpx from deprecated import deprecated from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType +from app.core.config import settings T = TypeVar("T", bound=BaseVariable) @@ -80,8 +82,23 @@ class FileVariable(BaseVariable): def get_value(self) -> Any: return self.value.model_dump() + async def get_content(self): + total_bytes = 0 + chunks = [] -class ArrayObject(BaseVariable, Generic[T]): + async with httpx.AsyncClient() as client: + async with client.stream("GET", self.value.url) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(8192): + total_bytes += len(chunk) + if total_bytes > settings.MAX_FILE_SIZE: + raise ValueError(f"File too large: {total_bytes} bytes") + chunks.append(chunk) + + return b"".join(chunks) + + +class ArrayVariable(BaseVariable, Generic[T]): type = 'array' def __init__(self, child_type: Type[T], value: list[Any]): @@ -108,7 +125,7 @@ class ArrayObject(BaseVariable, Generic[T]): return [v.get_value() for v in self.value] -class NestedArrayObject(BaseVariable): +class NestedArrayVariable(BaseVariable): type = 'array_nest' def valid_value(self, value: list[T]) -> list[T]: @@ -116,23 +133,23 @@ class NestedArrayObject(BaseVariable): raise TypeError(f"Value must be a list - {type(value)}:{value}") final_value = [] for v in value: - if not isinstance(v, ArrayObject): + if not isinstance(v, list): raise TypeError("All elements must be of type list") - final_value.append(v) + final_value.append(make_array(AnyVariable, v)) return final_value def to_literal(self) -> str: - return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) + return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value]) def get_value(self) -> Any: - return [[item.get_value() for item in row] for row in self.value] + return [[item for item in row.get_value()] for row in self.value] @deprecated( reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", category=RuntimeWarning ) -class AnyObject(BaseVariable): +class AnyVariable(BaseVariable): type = 'any' def valid_value(self, value: Any) -> Any: @@ -142,10 +159,10 @@ class AnyObject(BaseVariable): return str(self.value) -def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: - """简化 ArrayObject 创建,不需要重复写类型""" +def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]: + """简化 ArrayVariable 创建,不需要重复写类型""" - return ArrayObject(child_type, value) + return ArrayVariable(child_type, value) def create_variable_instance(var_type: VariableType, value: Any) -> T: @@ -168,7 +185,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T: return make_array(DictVariable, value) case VariableType.ARRAY_FILE: return make_array(FileVariable, value) + case VariableType.NESTED_ARRAY: + return NestedArrayVariable(value) case VariableType.ANY: - return AnyObject(value) + return AnyVariable(value) case _: raise TypeError(f"Invalid type - {var_type}") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 188ef6cd..ffcf8b0c 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -580,6 +580,7 @@ class WorkflowService: # "variables": result.get("variables"), # "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) + "message": result.get("output"), # 最终输出(字符串) # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), From 9962a61c21ecf54c7040f6a34bcef8807699ff7f Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 15:54:35 +0800 Subject: [PATCH 038/164] feat(web): update app api --- web/src/views/ApplicationConfig/Api.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/ApplicationConfig/Api.tsx b/web/src/views/ApplicationConfig/Api.tsx index c4b0fefb..22cec3e8 100644 --- a/web/src/views/ApplicationConfig/Api.tsx +++ b/web/src/views/ApplicationConfig/Api.tsx @@ -29,7 +29,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { const { t } = useTranslation(); const activeMethods = ['POST']; const { message, modal } = App.useApp() - const copyContent = window.location.origin + '/v1/chat' + const copyContent = window.location.origin + '/v1/app/chat' const apiKeyModalRef = useRef(null); const apiKeyConfigModalRef = useRef(null); const [apiKeyList, setApiKeyList] = useState([]) From 5abfcdfbe8c282f2389a79753b2217526010328e Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 17:07:29 +0800 Subject: [PATCH 039/164] feat(web): add unknown node --- web/src/assets/images/workflow/unknown.svg | 26 +++++++ web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 7 +- .../Workflow/components/Properties/index.tsx | 73 ++++++++++++++++--- web/src/views/Workflow/constant.ts | 5 ++ .../views/Workflow/hooks/useWorkflowGraph.ts | 4 +- 6 files changed, 104 insertions(+), 12 deletions(-) create mode 100644 web/src/assets/images/workflow/unknown.svg diff --git a/web/src/assets/images/workflow/unknown.svg b/web/src/assets/images/workflow/unknown.svg new file mode 100644 index 00000000..4c8198dd --- /dev/null +++ b/web/src/assets/images/workflow/unknown.svg @@ -0,0 +1,26 @@ + + + 未知节点 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 9df6d018..fe38734d 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1980,6 +1980,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re evolutionAndGovernance: 'Evolution & Governance', self_optimization: 'Self Optimization', process_evolution: 'Process Evolution', + unknown: 'Unknown Node', clickToConfigure: 'Click to configure node parameters', nodeProperties: 'Node Properties', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 3fe37ea8..19ff6d53 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1977,6 +1977,7 @@ export const zh = { evolutionAndGovernance: '演化与治理', self_optimization: '自我优化', process_evolution: '流程演化', + unknown: '未知节点', clickToConfigure: '点击配置节点参数', nodeProperties: '节点属性', @@ -2164,6 +2165,9 @@ export const zh = { output_variables: '输出变量', refreshTip: '同步函数签名至代码', }, + unknown: { + replaceNodeType: '替换节点' + }, name: '键', type: '类型', value: '值', @@ -2195,7 +2199,8 @@ export const zh = { iteration: '迭代', input_cycle_vars: '初始循环变量', output_cycle_vars: '最终循环变量', - } + }, + sureReplace: '确认替换', }, emotionEngine: { emotionEngineConfig: '情感引擎配置', diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index e96b1757..76fc9ad0 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -1,14 +1,14 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 15:39:59 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-11 12:07:06 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-02 17:06:41 */ import { type FC, useEffect, useState, useMemo } from "react"; import clsx from 'clsx' import { useTranslation } from 'react-i18next' import { Graph, Node } from '@antv/x6'; -import { Form, Input, Select, InputNumber, Switch, Divider, Space } from 'antd' +import { Form, Input, Select, InputNumber, Switch, Divider, Space, Button } from 'antd' import { CaretDownOutlined, CaretRightOutlined } from '@ant-design/icons'; import type { NodeConfig, NodeProperties, ChatVariable } from '../../types' @@ -36,6 +36,7 @@ import Editor, { type LexicalEditorProps } from "../Editor"; import RbSlider from './RbSlider' import JinjaRender from './JinjaRender' import CodeExecution from './CodeExecution' +import { nodeLibrary } from '../../constant'; /** * Props for Properties component @@ -69,7 +70,8 @@ interface PropertiesProps { const Properties: FC = ({ selectedNode, graphRef, - chatVariables + chatVariables, + blankClick }) => { const { t } = useTranslation() const [form] = Form.useForm(); @@ -80,9 +82,8 @@ const Properties: FC = ({ useEffect(() => { if (selectedNode?.getData()?.id) { setOutputCollapsed(true) - } else { - form.resetFields() } + form.resetFields() }, [selectedNode?.getData()?.id]) useEffect(() => { @@ -94,7 +95,7 @@ const Properties: FC = ({ initialValue[key] = config[key].defaultValue } }) - + form.setFieldsValue({ type, id: selectedNode.id, @@ -380,6 +381,41 @@ const Properties: FC = ({ } } console.log('variableList', variableList, currentNodeVariables) + const handleSureReplace = () => { + const { replaceNode } = values; + const nodeLibraryConfig = [...nodeLibrary] + .flatMap(category => category.nodes) + .find(n => n.type === replaceNode) + + if (replaceNode && nodeLibraryConfig) { + // Preserve existing config values when switching node types + const currentData = selectedNode?.data || {}; + const currentConfig = currentData.config || {}; + const newConfig = nodeLibraryConfig.config || {}; + + // Merge configs: keep existing values for matching keys, add new keys from template + const mergedConfig: Record = {}; + Object.keys(newConfig).forEach(key => { + if (currentConfig[key] && currentConfig[key].defaultValue !== undefined) { + // Preserve existing value if it exists + mergedConfig[key] = { + ...newConfig[key], + defaultValue: currentConfig[key].defaultValue + }; + } else { + // Use new config template + mergedConfig[key] = { ...newConfig[key] }; + } + }); + + selectedNode?.setData({ + ...currentData, + ...nodeLibraryConfig, + config: mergedConfig + }) + blankClick() + } + } return (
@@ -399,8 +435,27 @@ const Properties: FC = ({ - - {selectedNode?.data?.type === 'http-request' + {selectedNode?.data?.type === 'unknown' + ? <> + + ({ @@ -270,16 +277,17 @@ const UploadWorkflowModal = forwardRef - + { - console.log('文件列表变化:', fileList); - }} /> diff --git a/web/src/views/ApplicationManagement/index.tsx b/web/src/views/ApplicationManagement/index.tsx index ca9888de..055c0c8f 100644 --- a/web/src/views/ApplicationManagement/index.tsx +++ b/web/src/views/ApplicationManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:34:12 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 13:52:22 + * @Last Modified time: 2026-03-02 17:48:51 */ /** * Application Management Page @@ -12,7 +12,7 @@ import React, { useState, useRef, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -import { Button, Row, Col, App, Select, Space } from 'antd'; +import { Button, Row, Col, App, Select, Space, Dropdown } from 'antd'; import clsx from 'clsx'; import { DeleteOutlined } from '@ant-design/icons'; import { useSearchParams } from 'react-router-dom' @@ -86,6 +86,13 @@ const ApplicationManagement: React.FC = () => { const handleImport = () => { uploadWorkflowModalRef.current?.handleOpen() } + const handleClick = ({ key }: { key: string } ) => { + switch (key) { + case 'thirdParty': + handleImport() + break; + } + } return ( <> @@ -111,9 +118,16 @@ const ApplicationManagement: React.FC = () => { - + + + From 6cebddf8938933e91ce29a9518cad4f4da99bae7 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 18:14:36 +0800 Subject: [PATCH 043/164] feat(web): workflow runtime add error info --- web/src/components/Chat/types.ts | 1 + .../views/Workflow/components/Chat/Chat.tsx | 3 ++- .../Workflow/components/Chat/Runtime.tsx | 22 ++++++++++++------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/web/src/components/Chat/types.ts b/web/src/components/Chat/types.ts index 96e8e284..0cf1b130 100644 --- a/web/src/components/Chat/types.ts +++ b/web/src/components/Chat/types.ts @@ -23,6 +23,7 @@ export interface ChatItem { status?: string; subContent?: Record[]; files?: any[]; + error?: string; } /** diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index 8ca6efac..65989b30 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -320,7 +320,8 @@ const Chat = forwardRef(({ appId newList[lastIndex] = { ...newList[lastIndex], status, - content: newList[lastIndex].content === '' ? null : newList[lastIndex].content + error, + content: newList[lastIndex].content === '' ? null : newList[lastIndex].content, } } return newList diff --git a/web/src/views/Workflow/components/Chat/Runtime.tsx b/web/src/views/Workflow/components/Chat/Runtime.tsx index e3608e10..e41531b0 100644 --- a/web/src/views/Workflow/components/Chat/Runtime.tsx +++ b/web/src/views/Workflow/components/Chat/Runtime.tsx @@ -217,14 +217,20 @@ const Runtime: FC<{ item: ChatItem; index: number;}> = ({ children: ( detail ? ( -
- - {renderDetailChild(detail.subContent)} -
- ) - : renderChild(item.subContent) +
+ + {renderDetailChild(detail.subContent)} +
+ ) + : <> + {item.error + ?
+ +
+ : renderChild(item.subContent) + } ) }]} /> From ce8a2cbe34db7c63c218c3fe3a2f9a0e1cb4cdc1 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 18:32:19 +0800 Subject: [PATCH 044/164] feat(web): update file type --- .../ApplicationManagement/components/UploadWorkflowModal.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx index f09328a5..503ea55a 100644 --- a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx +++ b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx @@ -287,7 +287,7 @@ const UploadWorkflowModal = forwardRef
From c9a8753473bcae908c9fa77281634ea21b43dea5 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 2 Mar 2026 18:38:08 +0800 Subject: [PATCH 045/164] revert(web): revert file --- web/.gitignore | 8 +- web/vite.config.ts | 14 +- web/变量信息.md | 654 --------------------------------------------- 3 files changed, 3 insertions(+), 673 deletions(-) delete mode 100644 web/变量信息.md diff --git a/web/.gitignore b/web/.gitignore index 2a94c851..89a253b3 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -23,10 +23,4 @@ dist-ssr *.sln *.sw? vite.config.js -package-lock.json - -src/test/* -src/*/__tests__/* -vitest.config.ts -public/vitest-auto-imports.d.ts -package_test.json \ No newline at end of file +package-lock.json \ No newline at end of file diff --git a/web/vite.config.ts b/web/vite.config.ts index cf3f5013..88b3cd75 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -3,7 +3,6 @@ import react from '@vitejs/plugin-react' import { resolve } from 'path' import AutoImport from 'unplugin-auto-import/vite' import tailwindcss from '@tailwindcss/vite' -import svgr from 'vite-plugin-svgr' // https://vite.dev/config/ export default defineConfig({ @@ -12,15 +11,7 @@ export default defineConfig({ proxy: { // 主要API代理,支持 /api 和 /api/* 格式 '/api': { - // target: 'http://192.168.110.86:8000', // lxy - // target: 'http://192.168.110.25:8000', // xjn - // target: 'http://192.168.110.217:8000', // llq - target: 'http://192.168.110.111:8000', // myh - // target: 'https://devmemorybear.redbearai.com/', // 开发后端服务地址 - // target: 'https://devcopymemorybear.redbearai.com/', // 开发sass后端服务地址 - // target: 'https://testmemorybear.redbearai.com/', // 测试后端服务地址 - // target: 'https://memorybear.redbearai.com/', // 预发服务地址 - // target: 'https://cloud.memorybear.ai/', // AMAZON 生产地址 + target: 'http://0.0.0.0:5173', // 后端服务地址 changeOrigin: true, // 匹配所有以/api开头的请求,包括/api/token @@ -35,7 +26,6 @@ export default defineConfig({ }, plugins: [ tailwindcss(), - svgr({ svgrOptions: { icon: true } }), react(), AutoImport({ imports: ['react', 'react-router-dom'], @@ -98,4 +88,4 @@ export default defineConfig({ }, }, }, -}) +}) \ No newline at end of file diff --git a/web/变量信息.md b/web/变量信息.md deleted file mode 100644 index 008af6b7..00000000 --- a/web/变量信息.md +++ /dev/null @@ -1,654 +0,0 @@ -# 系统变量:需和开始节点拆分 - -# end: string/number/boolean/object/array[file]/array[object]/array[number]/array[string] - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __is_success / __reason - - memory-read: answer / intermediate_outputs - - - - question-classifier: class_name / output - - iteration: output - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - jinja-render: output - -# llm: 不能选 boolean 类型 -## 上下文:string/number/array[file]/array[object]/array[string]/array[number]; 不要object - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __is_success / __reason - - memory-read: answer / intermediate_outputs - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -## 提示词: string/number/array[file]/array[number]/array[string]; 不要object,boolean - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __is_success / __reason - - memory-read: answer / intermediate_outputs - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# knowledge-retrieval: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - -# parameter-extractor: -## 输入变量: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output -## 指令:string/number - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __is_success / __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# memory-read: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - - -# memory-write: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - -# if-else: boolean/string/number/array[file]/array[object]/array[string]/object - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __is_success / __reason - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# question-classifier -## 输入变量: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output -## 分类: string - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - -# iteration -## 输入变量: array[file] | array[object] | array[string] | array[number] | array[boolean] - - - - - - knowledge-retrieval: output - - parameter-extractor: array类型的提取参数 params - - - - - - iteration: output - - loop: cycle_vars - - - - - - code: output_variables - -## 输出变量 - - 系统变量 - - - - - - - - - - - - - - - - - - - 子节点的输出变量 - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason, params - - memory-read: answer - - memory-write - - - question-classifier: class_name - - - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - -# loop -## 循环变量 - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __is_success / __reason / params - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output -## 循环终止条件 -### left - - 系统变量 - - 会话变量 - - - - - - - - - - - loop: cycle_vars 当前loop节点的 - - - - - - code: output_variables - - - 子节点的输出变量 - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason - - memory-read: answer - - memory-write - - - question-classifier: class_name - - - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - jinja-render: output -### right: number - - 系统变量 - - 会话变量 - - start: variables / sys - - - - parameter-extractor: __is_success - - - - - - - loop: cycle_vars 当前loop节点的 - - - - http-request: status_code - - - code: output_variables - - -# var-aggregator: string/number/boolean - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# assigner -## variable_selector - - - 会话变量 - - - - - - - - - - - loop: cycle_vars 当前loop节点的 - - - - - - -## value - - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason / __is_success - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body - - tool: data - - code: output_variables - - jinja-render: output - -# http-request -## url/headers/params: string/number - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason / __is_success - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output -## ['body', 'data'] -### body?.content_type = form-data/x-www-form-urlencoded/json/raw: string/number -### body?.content_type = binary: file/array[file] - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - - parameter-extractor: __reason / __is_success - - memory-read: answer - - - - question-classifier: class_name - - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# tool: 不需要 - -# jinja-render -## mappingList 输入变量 - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason / __is_success - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - code: output_variables - - jinja-render: output - -# code -## input_variables - - 系统变量 - - 会话变量 - - start: variables / sys - - llm: output - - knowledge-retrieval: output - - parameter-extractor: __reason / __is_success - - memory-read: answer - - - - question-classifier: class_name - - iteration: output - - loop: cycle_vars 当前loop节点的 - - var-aggregator - - group = false 时,output - - group = true 时,group_variables - - - http-request: body / status_code - - tool: data - - jinja-render: output - - code: output_variables - - code: output_variables - - -# 迭代子节点 -- llm -- if-else -- parameter-extractor && prompt -- var-aggregator -- assigner -- http-request && body.content_type !== 'binary' -- tool -- jinja-render - - iteration: item / index - -- knowledge-retrieval -- parameter-extractor && !prompt -- memory-read -- memory-write -- question-classifier - - iteration的输入变量是array[string]时,可选item - -- iteration -- loop - - 不可添加此类节点 - -# 循环子节点 -- llm -- knowledge-retrieval -- parameter-extractor -- memory-read -- memory-write -- if-else -- question-classifier -- var-aggregator -- assigner -- http-request -- tool -- jinja-render - - loop: cycle_vars - -- iteration -- loop - - 不可添加此类节点 - - - -# TODO - -## 需要后端支持的需求 -1. 集群调试:对话过程数据输出【需后端】 - -3. 应用调试、分享增加变量配置【需后端】 -4. 应用导入导出,导出已完成,导入【需后端】 -6. 单个节点的运行【需后端】 -7. 列表 节点的配置【需后端】 -9. 对话支持附件(非图片)【需后端】 - -## 前端需求 -1. 工作流整理布局、历史撤销、回退 -2. 问题分类节点,分类中英文 -3. 感知记忆:文本类型增加片段展示 -- variableConfig -4. 工作流UI - - - - - 变量聚合器 -7. 记忆萃取 - - 本体场景不可编辑 -- rb:truncate -- 注释翻译 - - RbCard - - src/views/KnowledgeBase/index.tsx - - src/components/Upload/UploadFiles.tsx - - src/components/Chat - - -# 分支 - -## 0.2.6 -- feature/workflow_import_zy - - 工作流导入 | 导出 - - input_type: Constant / Variable 统一成小写 - - 结束节点内容被覆写 - - 增加未知节点 - - http 节点 - - 变量下拉列表替换成编辑器 - - body form-data file时,值支持选择sys.files -- feature/form_zy - - 表单校验规则 - - 流式输出时,向上滚动后,自动滚动到最底部的效果失效 - - 应用 API URL更新 -- feature/memory_zy - - 记忆萃取增加剪枝 - -## 20260212 -1. A2A 协议适配 -2. 日志跟踪系统 -3. Agent、集群、工作流共享 -4. 试运行、分享会话支持文件(包含语言、其他附件)【待联调】 -2. 导入 Agent、工作流 - 合并到应用管理创建方式 - - -a7da914dcbb80186b9aaf9ac4d21a9881e60ecb5 -e115353811b34de2fd359962860fdafe87fef503 \ No newline at end of file From aa733354e8f774571df3c0926934bfbccbea40ad Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 10:14:36 +0800 Subject: [PATCH 046/164] fix(web): Editor input type add blur event --- .../components/Editor/commands/index.ts | 16 ++++- .../Workflow/components/Editor/index.tsx | 33 +++++++++-- .../Editor/plugin/AutocompletePlugin.tsx | 45 +++++++++++--- .../components/Editor/plugin/BlurPlugin.tsx | 59 +++++++++++++------ .../Properties/HttpRequest/EditableTable.tsx | 2 +- 5 files changed, 123 insertions(+), 32 deletions(-) diff --git a/web/src/views/Workflow/components/Editor/commands/index.ts b/web/src/views/Workflow/components/Editor/commands/index.ts index 7f30c46a..839e13f1 100644 --- a/web/src/views/Workflow/components/Editor/commands/index.ts +++ b/web/src/views/Workflow/components/Editor/commands/index.ts @@ -1,13 +1,25 @@ +/* + * @Author: ZhaoYing + * @Date: 2025-12-23 12:29:46 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 10:12:48 + */ import { createCommand, type LexicalCommand } from 'lexical'; import type { Suggestion } from '../plugin/AutocompletePlugin'; - +// Payload interface for inserting variable command export interface InsertVariableCommandPayload { data: Suggestion; } +// Command to insert a variable into the editor export const INSERT_VARIABLE_COMMAND: LexicalCommand = createCommand('INSERT_VARIABLE_COMMAND'); +// Command to clear all editor content export const CLEAR_EDITOR_COMMAND: LexicalCommand = createCommand('CLEAR_EDITOR_COMMAND'); -export const FOCUS_EDITOR_COMMAND: LexicalCommand = createCommand('FOCUS_EDITOR_COMMAND'); \ No newline at end of file +// Command to focus the editor +export const FOCUS_EDITOR_COMMAND: LexicalCommand = createCommand('FOCUS_EDITOR_COMMAND'); + +// Command to close the autocomplete dropdown +export const CLOSE_AUTOCOMPLETE_COMMAND: LexicalCommand = createCommand('CLOSE_AUTOCOMPLETE_COMMAND'); \ No newline at end of file diff --git a/web/src/views/Workflow/components/Editor/index.tsx b/web/src/views/Workflow/components/Editor/index.tsx index 5e376cc8..4707e20c 100644 --- a/web/src/views/Workflow/components/Editor/index.tsx +++ b/web/src/views/Workflow/components/Editor/index.tsx @@ -1,3 +1,9 @@ +/* + * @Author: ZhaoYing + * @Date: 2025-12-23 16:22:51 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 10:11:48 + */ import { type FC, useState, useEffect, useMemo } from 'react'; import { LexicalComposer } from '@lexical/react/LexicalComposer'; import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'; @@ -19,6 +25,7 @@ import LineNumberPlugin from './plugin/LineNumberPlugin'; import BlurPlugin from './plugin/BlurPlugin'; import { VariableNode } from './nodes/VariableNode' +// Props interface for Lexical Editor component export interface LexicalEditorProps { placeholder?: string; value?: string; @@ -34,6 +41,7 @@ export interface LexicalEditorProps { className?: string; } +// Default theme for editor const theme = { paragraph: 'editor-paragraph', text: { @@ -42,6 +50,7 @@ const theme = { }, }; +// Theme with Jinja2 syntax highlighting const jinja2Theme = { ...theme, code: 'jinja2-expression', @@ -51,7 +60,8 @@ const jinja2Theme = { }, }; -const Editor: FC =({ +// Main Lexical Editor component +const Editor: FC =(({ placeholder = "请输入内容...", value = "", onChange, @@ -67,6 +77,7 @@ const Editor: FC =({ const [enableJinja2, setEnableJinja2] = useState(false) const [enableLineNumbers, setEnableLineNumbers] = useState(false) + // Setup Jinja2 mode and inject styles when language changes useEffect(() => { const needsLineNumbers = language === 'jinja2'; setEnableJinja2(language === 'jinja2'); @@ -139,11 +150,12 @@ const Editor: FC =({ } }, [language]) + // Lexical editor configuration const initialConfig = { namespace: 'AutocompleteEditor', theme: enableJinja2 ? jinja2Theme : theme, nodes: enableJinja2 ? [ - // 当启用jinja2时,不使用VariableNode,使用普通文本 + // When Jinja2 is enabled, use plain text instead of VariableNode ] : [ // HeadingNode, // QuoteNode, @@ -157,18 +169,26 @@ const Editor: FC =({ console.error(error); }, }; + + // Calculate minimum height based on type and size const minheight = useMemo(() => { if (type === 'input') { return `${height ? height : size === 'small' ? 28 : 30}px` } return `${height ? height : size === 'small' ? 60 : 120}px` }, [type, size, height]) + + // Calculate font size based on size prop const fontSize = useMemo(() => { return `${size === 'small' ? 12 : 14}px` }, [size]) + + // Calculate line height based on size prop const lineHeight = useMemo(() => { return `${height ? height : size === 'small' ? 16 : 20}px` }, [size]) + + // Calculate placeholder minimum height const placeHolderMinheight = useMemo(() => { return `${height ? height : size === 'small' ? 16 : 30}px` }, [type, size, height]) @@ -179,6 +199,7 @@ const Editor: FC =({ =({
) : ( + // Standard editor without line numbers =({ } ErrorBoundary={LexicalErrorBoundary} /> + {/* Editor plugins */} {language === 'jinja2' && } @@ -242,10 +265,10 @@ const Editor: FC =({ { setCount(count) }} onChange={onChange} /> - {enableJinja2 && } +
); -}; +}); -export default Editor; \ No newline at end of file +export default Editor; diff --git a/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx index 8e2687f1..25ef511f 100644 --- a/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/AutocompletePlugin.tsx @@ -1,10 +1,17 @@ +/* + * @Author: ZhaoYing + * @Date: 2025-12-23 16:22:51 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 10:12:33 + */ import { useEffect, useState, type FC } from 'react'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { $getSelection, $isRangeSelection, $isTextNode, COMMAND_PRIORITY_HIGH, KEY_ENTER_COMMAND, KEY_ARROW_DOWN_COMMAND, KEY_ARROW_UP_COMMAND, KEY_ESCAPE_COMMAND } from 'lexical'; -import { INSERT_VARIABLE_COMMAND } from '../commands'; +import { INSERT_VARIABLE_COMMAND, CLOSE_AUTOCOMPLETE_COMMAND } from '../commands'; import type { NodeProperties } from '../../../types' +// Suggestion item interface for autocomplete dropdown export interface Suggestion { key: string; label: string; @@ -13,16 +20,18 @@ export interface Suggestion { value: string; group?: string nodeData: NodeProperties; - isContext?: boolean; // 标记是否为context变量 - disabled?: boolean; // 标记是否禁用 + isContext?: boolean; // Flag for context variable + disabled?: boolean; // Flag for disabled state } +// Autocomplete plugin for variable suggestions triggered by '/' character const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => { const [editor] = useLexicalComposerContext(); const [showSuggestions, setShowSuggestions] = useState(false); const [selectedIndex, setSelectedIndex] = useState(0); const [popupPosition, setPopupPosition] = useState({ top: 0, left: 0 }); + // Listen to editor updates and show suggestions when '/' is typed useEffect(() => { return editor.registerUpdateListener(({ editorState }) => { editorState.read(() => { @@ -49,6 +58,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> setSelectedIndex(0); } + // Calculate popup position to keep it within viewport bounds if (shouldShow) { const domSelection = window.getSelection(); if (domSelection && domSelection.rangeCount > 0) { @@ -84,9 +94,22 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> }); }, [editor]); + // Register command to close autocomplete popup + useEffect(() => { + return editor.registerCommand( + CLOSE_AUTOCOMPLETE_COMMAND, + () => { + setShowSuggestions(false); + return true; + }, + COMMAND_PRIORITY_HIGH + ); + }, [editor]); + + // Insert selected suggestion into editor const insertMention = (suggestion: Suggestion) => { if (enableJinja2) { - // 在jinja2模式下,插入{{variable}}格式的文本 + // In Jinja2 mode, insert {{variable}} format text editor.update(() => { const selection = $getSelection(); if ($isRangeSelection(selection)) { @@ -94,7 +117,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> const anchorOffset = selection.anchor.offset; const nodeText = anchorNode.getTextContent(); - // 移除触发字符'/' + // Remove trigger character '/' const textBefore = nodeText.substring(0, anchorOffset - 1); const textAfter = nodeText.substring(anchorOffset); const newText = textBefore + `{{${suggestion.value}}}` + textAfter; @@ -103,19 +126,20 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> anchorNode.setTextContent(newText); } - // 设置光标位置到插入文本之后 + // Set cursor position after inserted text const newOffset = textBefore.length + `{{${suggestion.value}}}`.length; selection.anchor.offset = newOffset; selection.focus.offset = newOffset; } }); } else { - // 普通模式下使用VariableNode + // In normal mode, use VariableNode editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion }); } setShowSuggestions(false); }; + // Group suggestions by node ID const groupedSuggestions = options.reduce((groups: Record, suggestion) => { const { nodeData } = suggestion const nodeId = nodeData.id as string; @@ -126,6 +150,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> return groups; }, {}); + // Handle Enter key to select suggestion useEffect(() => { if (!showSuggestions) return; @@ -148,11 +173,13 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> ); }, [showSuggestions, selectedIndex, groupedSuggestions, insertMention, editor]); + // Handle keyboard navigation (Arrow Up/Down, Escape) useEffect(() => { if (!showSuggestions) return; const allOptions = Object.values(groupedSuggestions).flat(); + // Navigate down through suggestions, skip disabled items const unregisterArrowDown = editor.registerCommand( KEY_ARROW_DOWN_COMMAND, (event) => { @@ -172,6 +199,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> COMMAND_PRIORITY_HIGH ); + // Navigate up through suggestions, skip disabled items const unregisterArrowUp = editor.registerCommand( KEY_ARROW_UP_COMMAND, (event) => { @@ -191,6 +219,7 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> COMMAND_PRIORITY_HIGH ); + // Close suggestions on Escape key const unregisterEscape = editor.registerCommand( KEY_ESCAPE_COMMAND, (event) => { @@ -239,7 +268,9 @@ const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> const nodeName = nodeOptions[0]?.nodeData?.name || nodeId; return (
+ {/* Divider between groups */} {groupIndex > 0 &&
} + {/* Group header with node name */}
{nodeName}
diff --git a/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx index 0fb6c48f..13eb48b6 100644 --- a/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/BlurPlugin.tsx @@ -1,39 +1,64 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-01-20 10:42:13 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 10:12:10 + */ import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; import { useEffect } from 'react'; import { $setSelection } from 'lexical'; +import { CLOSE_AUTOCOMPLETE_COMMAND } from '../commands'; -export default function BlurPlugin() { +// Plugin to handle blur events and close autocomplete when clicking outside +export default function BlurPlugin({ enableJinja2 }: { enableJinja2: boolean }) { const [editor] = useLexicalComposerContext(); useEffect(() => { + // Close autocomplete when clicking outside the popup + const handleClickOutside = (e: MouseEvent) => { + const target = e.target as HTMLElement; + if (target?.closest('[data-autocomplete-popup="true"]')) { + return; + } + editor.dispatchCommand(CLOSE_AUTOCOMPLETE_COMMAND, undefined); + }; + + document.addEventListener('mousedown', handleClickOutside); + return editor.registerRootListener((rootElement) => { if (rootElement) { const handleBlur = (e: FocusEvent) => { - // 检查是否点击了自动完成弹窗 - const target = e.target as HTMLElement; - console.log('target', target) - if (target?.closest('[data-autocomplete-popup="true"]')) { - return; + if (enableJinja2) { + // Check if autocomplete popup was clicked + const target = e.target as HTMLElement; + if (target?.closest('[data-autocomplete-popup="true"]')) { + return; + } + + // Check if blur was caused by paste operation + const relatedTarget = e.relatedTarget as HTMLElement; + if (!relatedTarget || relatedTarget === document.body) { + return; + } + + // Clear selection on blur + editor.update(() => { + $setSelection(null); + }); } - - // 检查是否是粘贴操作导致的焦点变化 - const relatedTarget = e.relatedTarget as HTMLElement; - if (!relatedTarget || relatedTarget === document.body) { - return; - } - - editor.update(() => { - $setSelection(null); - }); }; rootElement.addEventListener('blur', handleBlur); return () => { + document.removeEventListener('mousedown', handleClickOutside); rootElement.removeEventListener('blur', handleBlur); }; } + return () => { + document.removeEventListener('mousedown', handleClickOutside); + }; }); - }, [editor]); + }, [editor, enableJinja2]); return null; } diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx index ead15759..74593913 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx @@ -49,7 +49,7 @@ const EditableTable: FC = ({ const getColumns = (remove: (index: number) => void): TableProps['columns'] => { const hasType = typeOptions.length > 0; const cellClassName="rb:p-1!" - const contentClassName ="rb:w-[108px]! rb:text-[12px]!" + const contentClassName ="rb:w-[108px]! rb:text-[12px]! rb:overflow-hidden!" return [ { From aa7d52568b1bcaefa32b773d68101213364ed79e Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 10:24:21 +0800 Subject: [PATCH 047/164] fix(web): change string regExp --- web/src/i18n/en.ts | 2 +- web/src/i18n/zh.ts | 2 +- web/src/utils/validator.ts | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index c5c81f13..94dd67cc 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -455,6 +455,7 @@ export const en = { recommend: 'Recommend', logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, imageSquareRequired: 'Please upload a square image', + nameInvalid: 'Name cannot start or end with a space', }, model: { searchPlaceholder: 'search model…', @@ -545,7 +546,6 @@ export const en = { xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", - nameInvalid: 'Model name can only contain letters, numbers, underscores and spaces, cannot be empty or pure whitespace', }, modelNew: { group: 'Model Group', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index f63e5981..49632789 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1031,6 +1031,7 @@ export const zh = { recommend: '推荐', logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, imageSquareRequired: '请上传正方形比例图片', + nameInvalid: '不能是空格开头或结尾', }, model: { searchPlaceholder: '搜索模型…', @@ -1179,7 +1180,6 @@ export const zh = { xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", - nameInvalid: '模型名称只能包含字母、数字、下划线和空格, 不能为空或纯空格', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', diff --git a/web/src/utils/validator.ts b/web/src/utils/validator.ts index c55c52ca..650266ab 100644 --- a/web/src/utils/validator.ts +++ b/web/src/utils/validator.ts @@ -45,6 +45,5 @@ export const validateSquareImage = (errorMessage: string = 'Image must be square }; }; -// - Cannot be empty or pure whitespace -// - Cannot start with a space -export const stringRegExp = /^[a-zA-Z0-9\u4e00-\u9fa5][a-zA-Z0-9\u4e00-\u9fa5\s]*$/ \ No newline at end of file +// - Cannot start or end with a space +export const stringRegExp = /^(?!\s).*(? Date: Tue, 3 Mar 2026 10:27:01 +0800 Subject: [PATCH 048/164] feat(app): add API to retrieve app configuration fields --- .../controllers/public_share_controller.py | 62 ++- api/app/services/shared_chat_service.py | 383 +++++++++--------- api/app/services/workflow_service.py | 12 +- 3 files changed, 243 insertions(+), 214 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 9f5f8075..3c634ae0 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -2,25 +2,32 @@ import hashlib import json import uuid from typing import Annotated + from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db, get_db_read from app.dependencies import get_share_user_id, ShareTokenData +from app.models.app_model import App +from app.models.app_model import AppType from app.repositories import knowledge_repository +from app.repositories.end_user_repository import EndUserRepository from app.repositories.workflow_repository import WorkflowConfigRepository from app.schemas import release_share_schema, conversation_schema from app.schemas.response_schema import PageData, PageMeta from app.services import workspace_service +from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.auth_service import create_access_token from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService -from app.services.app_chat_service import AppChatService, get_app_chat_service -from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \ +from app.services.workflow_service import WorkflowService +from app.utils.app_config_utils import workflow_config_4_app_release, \ agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) @@ -206,15 +213,13 @@ def list_conversations( logger.debug(f"share_data:{share_data.user_id}") other_id = share_data.user_id service = SharedChatService(db) - share, release = service._get_release_by_share_token(share_data.share_token, password) - from app.repositories.end_user_repository import EndUserRepository + share, release = service.get_release_by_share_token(share_data.share_token, password) end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, other_id=other_id ) logger.debug(new_end_user.id) - service = SharedChatService(db) conversations, total = service.list_conversations( share_token=share_data.share_token, user_id=str(new_end_user.id), @@ -293,19 +298,15 @@ async def chat( # 提前验证和准备(在流式响应开始前完成) # 这样可以确保错误能正确返回,而不是在流式响应中间出错 - from app.models.app_model import AppType + try: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - from app.services.app_service import AppService # 验证分享链接和密码 - share, release = service._get_release_by_share_token(share_token, password) + share, release = service.get_release_by_share_token(share_token, password) # # Create end_user_id by concatenating app_id with user_id # end_user_id = f"{share.app_id}_{user_id}" # Store end_user_id in database with original user_id - from app.repositories.end_user_repository import EndUserRepository end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, @@ -318,7 +319,6 @@ async def chat( """获取存储类型和工作空间的ID""" # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用) - from app.models.app_model import App app = db.query(App).filter( App.id == appid, App.is_active.is_(True) @@ -359,12 +359,12 @@ async def chat( app_type = release.app.type if release.app else None # 根据应用类型验证配置 - if app_type == "agent": + if app_type == AppType.AGENT: # Agent 类型:验证模型配置 model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) - elif app_type == "multi_agent": + elif app_type == AppType.MULTI_AGENT: # Multi-Agent 类型:验证多 Agent 配置 config = release.config or {} if not config.get("sub_agents"): @@ -638,6 +638,34 @@ async def chat( # return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) else: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) + + +@router.get("/config", summary="获取应用启动配置") +async def config_query( + password: str = Query(None, description="访问密码"), + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), +): + share_service = SharedChatService(db) + share_token = share_data.share_token + share, release = share_service.get_release_by_share_token(share_token, password) + if release.app.type == AppType.WORKFLOW: + workflow_service = WorkflowService(db) + content = { + "app_type": release.app.type, + "variables": workflow_service.get_start_node_variables(release.config) + } + elif release.app.type == AppType.AGENT: + content = { + "app_type": release.app.type, + "variables": release.config.get("variables") + } + elif release.app.type == AppType.MULTI_AGENT: + content = { + "app_type": release.app.type, + "variables": [] + } + else: + return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED) + return success(data=content) diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index c7b81999..89d3f3d6 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -21,63 +21,64 @@ from app.repositories import knowledge_repository import json from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task + logger = get_business_logger() class SharedChatService: """基于分享链接的聊天服务""" - + def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) self.share_service = ReleaseShareService(db) - - def _get_release_by_share_token( - self, - share_token: str, - password: Optional[str] = None + + def get_release_by_share_token( + self, + share_token: str, + password: Optional[str] = None ) -> tuple[ReleaseShare, AppRelease]: """通过 share_token 获取发布版本""" # 获取分享配置 share = self.share_service.repo.get_by_share_token(share_token) if not share: raise ResourceNotFoundException("分享链接", share_token) - + # 验证分享是否启用 if not share.is_enabled: raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED) - + # 验证密码 if share.require_password: if not password: raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED) - + if not self.share_service.verify_password(share_token, password): raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD) - + # 获取发布版本 release = self.db.get(AppRelease, share.release_id) if not release: raise ResourceNotFoundException("发布版本", str(share.release_id)) - + # 更新访问统计 try: self.share_service.repo.increment_view_count(share.id) except Exception as e: logger.warning(f"更新访问统计失败: {str(e)}") - + return share, release - + def create_or_get_conversation( - self, - share_token: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - password: Optional[str] = None + self, + share_token: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + password: Optional[str] = None ) -> Conversation: """创建或获取会话""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 如果提供了 conversation_id,尝试获取现有会话 if conversation_id: try: @@ -85,18 +86,18 @@ class SharedChatService: conversation_id=conversation_id, workspace_id=release.app.workspace_id ) - + # 验证会话是否属于该应用 if conversation.app_id != release.app_id: raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - + return conversation except ResourceNotFoundException: logger.warning( "会话不存在,将创建新会话", extra={"conversation_id": str(conversation_id)} ) - + # 创建新会话(使用发布版本的配置) conversation = self.conversation_service.create_conversation( app_id=release.app_id, @@ -105,7 +106,7 @@ class SharedChatService: is_draft=False, # 分享链接使用发布版本 config_snapshot=release.config ) - + logger.info( "为分享链接创建新会话", extra={ @@ -114,25 +115,25 @@ class SharedChatService: "release_id": str(release.id) } ) - + return conversation - + async def chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" actual_config_id = None - config_id=actual_config_id + config_id = actual_config_id from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.model_parameter_merger import ModelParameterMerger @@ -140,32 +141,30 @@ class SharedChatService: from sqlalchemy import select from app.models import ModelApiKey - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取 Agent 配置 config = release.config or {} - # 获取模型配置ID model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - + # 获取模型配置 from app.models import ModelConfig model_config = self.db.get(ModelConfig, model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_config_id)) - + # 获取 API Key # stmt = ( # select(ModelApiKey).join( @@ -184,7 +183,7 @@ class SharedChatService: api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -192,7 +191,7 @@ class SharedChatService: user_id=user_id, password=password ) - + # 处理系统提示词(支持变量替换) system_prompt = config.get("system_prompt", "你是一个专业的AI助手") if variables: @@ -202,31 +201,31 @@ class SharedChatService: variables ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - + # 准备工具列表 tools = [] - + # 添加知识库检索工具 knowledge_retrieval = config.get("knowledge_retrieval") if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) tools.append(kb_tool) # 添加长期记忆工具 - memory_flag=False + memory_flag = False if memory: memory_config = config.get("memory", {}) if memory_config.get("enabled") and user_id: - memory_flag=True + memory_flag = True memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools=config.get("tools") + web_tools = config.get("tools") web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled",False) + web_search_enable = web_search_choice.get("enabled", False) if web_search: if web_search_enable: search_tool = create_web_search_tool({}) @@ -238,10 +237,10 @@ class SharedChatService: "tool_count": len(tools) } ) - + # 获取模型参数 model_parameters = config.get("model_parameters", {}) - + # 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_obj.model_name, @@ -254,10 +253,10 @@ class SharedChatService: tools=tools, ) - + # 加载历史消息 history = [] - memory_config={"enabled":True,'max_history':10} + memory_config = {"enabled": True, 'max_history': 10} if memory_config.get("enabled"): messages = self.conversation_service.get_messages( conversation_id=conversation.id, @@ -267,7 +266,7 @@ class SharedChatService: {"role": msg.role, "content": msg.content} for msg in messages ] - + # 调用 Agent result = await agent.chat( message=message, @@ -279,7 +278,7 @@ class SharedChatService: config_id=config_id, memory_flag=memory_flag ) - + # 保存消息 self.conversation_service.save_conversation_messages( conversation_id=conversation.id, @@ -298,7 +297,7 @@ class SharedChatService: # role="user", # content=message # ) - + # self.conversation_service.add_message( # conversation_id=conversation.id, # role="assistant", @@ -308,12 +307,11 @@ class SharedChatService: # "usage": result.get("usage", {}) # } # ) - + elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - return { "conversation_id": conversation.id, "message": result["content"], @@ -324,19 +322,19 @@ class SharedChatService: }), "elapsed_time": elapsed_time } - + async def chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, - storage_type:Optional[str] = None, - user_rag_memory_id: Optional[str] = None, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: """聊天(流式)""" from app.core.agent.langchain_agent import LangChainAgent @@ -345,36 +343,35 @@ class SharedChatService: from sqlalchemy import select from app.models import ModelApiKey import json - - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + start_time = time.time() + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} # 兼容新旧字段名:使用 memory_config_id memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10} - + try: # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取 Agent 配置 config = release.config or {} agent_config_data = config.get("agent_config", {}) - + # 获取模型配置ID model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - + # 获取模型配置 from app.models import ModelConfig model_config = self.db.get(ModelConfig, model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_config_id)) - + # 获取 API Key # stmt = ( # select(ModelApiKey).join( @@ -393,7 +390,7 @@ class SharedChatService: api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -401,7 +398,7 @@ class SharedChatService: user_id=user_id, password=password ) - + # 处理系统提示词(支持变量替换) system_prompt = config.get("system_prompt", "你是一个专业的AI助手") if variables: @@ -411,21 +408,21 @@ class SharedChatService: variables ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - + # 准备工具列表 tools = [] - + # 添加知识库检索工具 knowledge_retrieval = config.get("knowledge_retrieval") if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) tools.append(kb_tool) - + # 添加长期记忆工具 - memory_flag=False + memory_flag = False if memory: memory_config = config.get("memory", {}) if memory_config.get("enabled") and user_id: @@ -450,7 +447,7 @@ class SharedChatService: # 获取模型参数 model_parameters = config.get("model_parameters", {}) - + # 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_obj.model_name, @@ -463,7 +460,7 @@ class SharedChatService: tools=tools, streaming=True ) - + # 加载历史消息 history = [] memory_config = {"enabled": True, 'max_history': 10} @@ -476,22 +473,22 @@ class SharedChatService: {"role": msg.role, "content": msg.content} for msg in messages ] - + # 发送开始事件 yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - + # 流式调用 Agent full_content = "" total_tokens = 0 async for chunk in agent.chat_stream( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag + message=message, + history=history, + context=None, + end_user_id=user_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + config_id=config_id, + memory_flag=memory_flag ): if isinstance(chunk, int): total_tokens = chunk @@ -499,16 +496,16 @@ class SharedChatService: full_content += chunk # 发送消息块事件 yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -524,7 +521,7 @@ class SharedChatService: # 发送结束事件 end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" - + logger.info( "流式聊天完成", extra={ @@ -533,7 +530,7 @@ class SharedChatService: "message_length": len(full_content) } ) - + except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("流式聊天被中断") @@ -542,39 +539,39 @@ class SharedChatService: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" - + def get_conversation_messages( - self, - share_token: str, - conversation_id: uuid.UUID, - password: Optional[str] = None + self, + share_token: str, + conversation_id: uuid.UUID, + password: Optional[str] = None ) -> Conversation: """获取会话消息""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取会话 conversation = self.conversation_service.get_conversation( conversation_id=conversation_id, workspace_id=release.app.workspace_id ) - + # 验证会话是否属于该应用 if conversation.app_id != release.app_id: raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - + return conversation - + def list_conversations( - self, - share_token: str, - user_id: Optional[str] = None, - password: Optional[str] = None, - page: int = 1, - pagesize: int = 20 + self, + share_token: str, + user_id: Optional[str] = None, + password: Optional[str] = None, + page: int = 1, + pagesize: int = 20 ) -> tuple[list[Conversation], int]: """列出会话""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + conversations, total = self.conversation_service.list_conversations( app_id=release.app_id, workspace_id=release.app.workspace_id, @@ -583,19 +580,19 @@ class SharedChatService: page=page, pagesize=pagesize ) - + return conversations, total - + async def multi_agent_chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: @@ -603,18 +600,16 @@ class SharedChatService: from app.services.multi_agent_service import MultiAgentService from app.models import MultiAgentConfig - - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -622,19 +617,19 @@ class SharedChatService: user_id=user_id, password=password ) - + # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.is_active.is_(True) ).first() - + if not multi_agent_config: raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 构建多 Agent 运行请求 from app.schemas.multi_agent_schema import MultiAgentRunRequest - + multi_agent_request = MultiAgentRunRequest( message=message, conversation_id=conversation.id, @@ -644,23 +639,23 @@ class SharedChatService: web_search=web_search, memory=memory ) - + # 使用多 Agent 服务执行 multi_agent_service = MultiAgentService(self.db) result = await multi_agent_service.run( app_id=release.app_id, request=multi_agent_request ) - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -672,8 +667,6 @@ class SharedChatService: } ) - - return { "conversation_id": conversation.id, "message": result.get("message", ""), @@ -684,34 +677,33 @@ class SharedChatService: }, "elapsed_time": elapsed_time } - + async def multi_agent_chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, - user_rag_memory_id:Optional[str] = None + user_rag_memory_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """多 Agent 聊天(流式)""" - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + try: # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -719,28 +711,28 @@ class SharedChatService: user_id=user_id, password=password ) - + # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.is_active.is_(True) ).first() - + if not multi_agent_config: raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 获取 storage_type 和 user_rag_memory_id workspace_id = release.app.workspace_id storage_type = 'neo4j' # 默认值 user_rag_memory_id = '' - + try: # 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享) from app.models import Workspace workspace = self.db.get(Workspace, workspace_id) if workspace and workspace.storage_type: storage_type = workspace.storage_type - + # 获取 USER_RAG_MERORY 知识库 ID knowledge = knowledge_repository.get_knowledge_by_name( db=self.db, @@ -751,13 +743,13 @@ class SharedChatService: user_rag_memory_id = str(knowledge.id) except Exception as e: logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}") - + # 发送开始事件 yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - + # 构建多 Agent 运行请求 from app.schemas.multi_agent_schema import MultiAgentRunRequest - + multi_agent_request = MultiAgentRunRequest( message=message, conversation_id=conversation.id, @@ -767,20 +759,20 @@ class SharedChatService: web_search=web_search, memory=memory ) - + # 使用多 Agent 服务流式执行 multi_agent_service = MultiAgentService(self.db) full_content = "" - + async for event in multi_agent_service.run_stream( - app_id=release.app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + app_id=release.app_id, + request=multi_agent_request, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): # 直接转发事件 yield event - + # 尝试提取内容(用于保存) if "data:" in event: try: @@ -790,16 +782,16 @@ class SharedChatService: full_content += data["content"] except: pass - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -808,7 +800,7 @@ class SharedChatService: "elapsed_time": elapsed_time } ) - + logger.info( "多 Agent 流式聊天完成", extra={ @@ -818,7 +810,6 @@ class SharedChatService: } ) - except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("多 Agent 流式聊天被中断") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index ffcf8b0c..a388ca75 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config from app.db import get_db from app.models import App @@ -617,7 +618,8 @@ class WorkflowService: "event": "end", "data": { "elapsed_time": payload.get("elapsed_time"), - "message_length": len(payload.get("output", "")) + "message_length": len(payload.get("output", "")), + "error": payload.get("error", "") } } case "node_start" | "node_end" | "node_error" | "cycle_item": @@ -779,6 +781,14 @@ class WorkflowService: } } + @staticmethod + def get_start_node_variables(config: dict) -> list: + nodes = config.get("nodes", []) + for node in nodes: + if node.get("type") == NodeType.START: + return node.get("config", {}).get("variables", []) + raise BusinessException("workflow config error - start node not found") + def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: """清理事件数据,移除不可序列化的对象 From 537668b4633eb88e72d5f527df2ca8527b8f1857 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 3 Mar 2026 11:08:24 +0800 Subject: [PATCH 049/164] Merge pull request #432 from SuanmoSuanyangTechnology/feature/workflow_import_zy Feature/workflow import zy --- api/app/schemas/app_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index eeb73a01..24475d5b 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -433,7 +433,7 @@ class AppChatRequest(BaseModel): user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") - files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)") + files: Optional[List[FileInput]] = Field(default=list, description="附件列表(支持多文件)") class DraftRunRequest(BaseModel): From 45a64dbbacacf3dd1c872ce8447eb62c55e0ede2 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 11:15:14 +0800 Subject: [PATCH 050/164] fix(web): agent's variables init update --- web/src/views/ApplicationConfig/Agent.tsx | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 4bee291b..6018c600 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-25 18:11:49 + * @Last Modified time: 2026-03-03 11:14:30 */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' @@ -175,6 +175,10 @@ const Agent = forwardRef((_props, ref) => { const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined : !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent + const variableList = variables?.map((item, index) => ({ + ...item, + index + })) || [] form.setFieldsValue({ ...response, tools: allTools, @@ -185,9 +189,10 @@ const Agent = forwardRef((_props, ref) => { skills: { ...skills, skill_ids: allSkills - } + }, + variables: [...variableList] }) - updateVariableList([...variables]) + updateVariableList([...variableList]) setData({ ...response, tools: allTools From 2b6d86e5918a5f3d8d3265487f23cbf039162286 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 3 Mar 2026 11:49:33 +0800 Subject: [PATCH 051/164] [changes] --- api/app/core/config.py | 2 +- .../core/memory/ontology_services}/General_purpose_entity.ttl | 0 api/app/core/memory/utils/ontology/ontology_parser.py | 2 +- api/env.example | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename api/{ => app/core/memory/ontology_services}/General_purpose_entity.ttl (100%) diff --git a/api/app/core/config.py b/api/app/core/config.py index 19998d32..6a2cf206 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -230,7 +230,7 @@ class Settings: # General Ontology Type Configuration # ======================================================================== # 通用本体文件路径列表(逗号分隔) - GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl") + GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl") # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" diff --git a/api/General_purpose_entity.ttl b/api/app/core/memory/ontology_services/General_purpose_entity.ttl similarity index 100% rename from api/General_purpose_entity.ttl rename to api/app/core/memory/ontology_services/General_purpose_entity.ttl diff --git a/api/app/core/memory/utils/ontology/ontology_parser.py b/api/app/core/memory/utils/ontology/ontology_parser.py index a8bd054c..d75a8905 100644 --- a/api/app/core/memory/utils/ontology/ontology_parser.py +++ b/api/app/core/memory/utils/ontology/ontology_parser.py @@ -327,7 +327,7 @@ class MultiOntologyParser: Example: >>> parser = MultiOntologyParser([ - ... "General_purpose_entity.ttl", + ... "app/core/memory/ontology_services/General_purpose_entity.ttl", ... "domain_specific.owl" ... ]) >>> registry = parser.parse_all() diff --git a/api/env.example b/api/env.example index e8074f82..d67bbf7c 100644 --- a/api/env.example +++ b/api/env.example @@ -139,7 +139,7 @@ SMTP_USER= SMTP_PASSWORD= # 本体类型融合配置 (记得写入env_example) -GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 +GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导) MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长 CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中 From a5034e84ba0a2982f1cf2f4e6da5da7fb5d1281c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 3 Mar 2026 12:18:52 +0800 Subject: [PATCH 052/164] fix(agent): fix issue where default runtime file list configuration was empty --- api/app/schemas/app_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 24475d5b..07875e13 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -433,7 +433,7 @@ class AppChatRequest(BaseModel): user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") - files: Optional[List[FileInput]] = Field(default=list, description="附件列表(支持多文件)") + files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)") class DraftRunRequest(BaseModel): From 304ccef1016a7e42ba1199fd32fe7133bb7d6a9d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 3 Mar 2026 12:30:09 +0800 Subject: [PATCH 053/164] chore(api): organize imports and refactor database context management --- .gitignore | 1 + api/app/tasks.py | 39 +++++++++++++-------------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 2fb41537..66d1beb2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ search_results.json api/migrations/versions tmp files +powers/ # Exclude dep files huggingface.co/ diff --git a/api/app/tasks.py b/api/app/tasks.py index 8e3aea85..299d188b 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,16 +1,16 @@ import asyncio -from concurrent.futures import ThreadPoolExecutor import json import os import re +import shutil import time import uuid -from uuid import UUID +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from math import ceil from pathlib import Path -import shutil from typing import Any, Dict, List, Optional +from uuid import UUID import redis import requests @@ -38,7 +38,7 @@ from app.db import get_db, get_db_context from app.models.document_model import Document from app.models.file_model import File from app.models.knowledge_model import Knowledge -from app.schemas import file_schema, document_schema +from app.schemas import document_schema, file_schema from app.services.memory_agent_service import MemoryAgentService from app.utils.config_utils import resolve_config_id @@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID): Document parsing, vectorization, and storage """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_document = None @@ -297,8 +298,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): build knowledge graph """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_documents = None @@ -932,24 +934,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: service = MemoryAgentService() return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1049,19 +1045,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: logger.info( f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() @@ -1069,11 +1061,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result - except Exception as e: - logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True) - raise - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1328,9 +1315,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: async def _run() -> Dict[str, Any]: from app.core.logging_config import get_api_logger - from app.models.workspace_model import Workspace from app.models.app_model import App from app.models.end_user_model import EndUser + from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all From 1792cb4d933fb8fa06212f2180e33145bb77d183 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 13:48:50 +0800 Subject: [PATCH 054/164] feat(web): chat add variables --- web/src/api/application.ts | 18 ++-- web/src/i18n/en.ts | 5 +- web/src/i18n/zh.ts | 5 +- web/src/views/Conversation/index.tsx | 83 ++++++++++++++++++- web/src/views/Conversation/types.ts | 3 +- .../components/Chat/VariableConfigModal.tsx | 6 +- .../Properties/VariableList/types.ts | 1 + 7 files changed, 102 insertions(+), 19 deletions(-) diff --git a/web/src/api/application.ts b/web/src/api/application.ts index f019103e..c769dd91 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 13:59:45 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-28 16:34:15 + * @Last Modified time: 2026-03-03 12:08:42 */ import { request } from '@/utils/request' import type { ApplicationModalData } from '@/views/ApplicationManagement/types' @@ -120,15 +120,19 @@ export const copyApplication = (app_id: string, new_name: string) => { export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => { return request.get(`/apps/${app_id}/statistics`, data) } -// 导出工作流 -export const exportWorkflow = (app_id: string, fileName: string) => { - return request.downloadFile(`/apps/${app_id}/workflow/export`, fileName, undefined, undefined, 'GET') -} -// 工作流上传+兼容性分析 +// Upload workflow and analyze compatibility export const importWorkflow = (formData: FormData) => { return request.uploadFile(`/apps/workflow/import`, formData) } -// 完成工作流导入 +// Complete workflow import export const completeImportWorkflow = (data: { temp_id: string; name?: string; description?: string }) => { return request.post(`/apps/workflow/import/save`, data) } +// Get experience config +export const getExperienceConfig = (share_token: string) => { + return request.get(`/public/share/config`, {}, { + headers: { + 'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}` + } + }) +} \ No newline at end of file diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 5f041a7c..bc1f0c45 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1684,7 +1684,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re fileType: 'File Type', image: 'Image', fileUrl: 'File URL', - addRemoteFile: 'Add Remote File' + addRemoteFile: 'Add Remote File', + variableConfig: 'Variable Configuration', }, login: { title: 'Red Bear Memory Science', @@ -2192,7 +2193,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re save: 'Save', export: 'Export', variableConfig: 'Variable Configuration', - variableRequired: 'Required', + variableRequired: 'Required, please configure variable values', addMessage: 'Add Message', answerDesc: 'Reply', addNode: 'Add Node', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 0c0f7562..826001ed 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1681,7 +1681,8 @@ export const zh = { fileType: '文件类型', image: '图片', fileUrl: '文件链接', - addRemoteFile: '添加远程文件' + addRemoteFile: '添加远程文件', + variableConfig: '变量配置', }, login: { title: '红熊记忆科学', @@ -2192,7 +2193,7 @@ export const zh = { save: '保存', export: '导出', variableConfig: '变量配置', - variableRequired: '必填', + variableRequired: '必填,请配置变量值', addMessage: '添加消息', answerDesc: '回复', addNode: '添加节点', diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 825ea834..f532ac53 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:58:03 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 17:41:05 + * @Last Modified time: 2026-03-03 13:46:22 */ /** * Conversation Page @@ -14,11 +14,12 @@ import { type FC, useState, useEffect, useRef } from 'react' import { useParams, useLocation } from 'react-router-dom' import { useTranslation } from 'react-i18next' import InfiniteScroll from 'react-infinite-scroll-component'; -import { Flex, Skeleton, Form, Dropdown, type MenuProps } from 'antd' +import { Flex, Skeleton, Form, Dropdown, type MenuProps, App } from 'antd' +import { SettingOutlined } from '@ant-design/icons' import clsx from 'clsx' import dayjs from 'dayjs' -import { getConversationHistory, sendConversation, getConversationDetail, getShareToken } from '@/api/application' +import { getConversationHistory, sendConversation, getConversationDetail, getShareToken, getExperienceConfig } from '@/api/application' import type { HistoryItem, QueryParams, UploadFileListModalRef } from './types' import Empty from '@/components/Empty' import { formatDateTime } from '@/utils/format'; @@ -37,12 +38,16 @@ import UploadFiles from './components/FileUpload' // import AudioRecorder from '@/components/AudioRecorder' import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import UploadFileListModal from './components/UploadFileListModal' +import type { VariableConfigModalRef } from '@/views/Workflow/types' +import type { Variable } from '@/views/Workflow/components/Properties/VariableList/types' +import VariableConfigModal from '@/views/Workflow/components/Chat/VariableConfigModal'; /** * Conversation component for shared applications */ const Conversation: FC = () => { const { t } = useTranslation() + const { message: messageApi } = App.useApp() const { token } = useParams() const location = useLocation() const searchParams = new URLSearchParams(location.search) @@ -64,6 +69,22 @@ const Conversation: FC = () => { const queryValues = Form.useWatch([], form) const uploadFileListModalRef = useRef(null) + + const variableConfigModalRef = useRef(null) + const [variables, setVariables] = useState([]) // Workflow input variables + + /** + * Opens the variable configuration modal + */ + const handleEditVariables = () => { + variableConfigModalRef.current?.handleOpen(variables) + } + /** + * Saves updated variable values from the modal + */ + const handleSave = (values: Variable[]) => { + setVariables([...values]) + } useEffect(() => { const shareToken = localStorage.getItem(`shareToken_${token}`) setShareToken(shareToken) @@ -81,6 +102,17 @@ const Conversation: FC = () => { getHistory() } }, [token, shareToken, page, hasMore, historyList]) + useEffect(() => { + if (shareToken && token) { + getExperienceConfig(token) + .then(res => { + const response = res as { variables: Variable[] } + setVariables(response.variables || []) + }) + } else { + setChatList([]) + } + }, [shareToken, token]) /** Group conversation history by date */ const groupHistoryByDate = (items: HistoryItem[]): Record => { @@ -191,12 +223,35 @@ const Conversation: FC = () => { }) } + const isNeedVariableConfig = variables.some(vo => vo.required && (vo.value === null || vo.value === undefined || vo.value === '')) + /** Send message and handle streaming response */ const handleSend = () => { if (!token || !shareToken) { return } const { files = [], ...rest } = queryValues || {} + // Validate required variables before sending + let isCanSend = true + const params: Record = {} + if (variables.length > 0) { + const needRequired: string[] = [] + variables.forEach(vo => { + params[vo.name] = vo.value ?? vo.defaultValue + + if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { + isCanSend = false + needRequired.push(vo.name) + } + }) + + if (needRequired.length) { + messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) + } + } + if (!isCanSend) { + return + } setLoading(true) setStreamLoading(true) addUserMessage(message, files) @@ -247,7 +302,8 @@ const Conversation: FC = () => { upload_file_id: file.response.data.file_id } } - }) + }), + variables: params }, handleStreamMessage, shareToken) .finally(() => { setLoading(false) @@ -384,6 +440,20 @@ const Conversation: FC = () => { {t(`memoryConversation.memory`)} + {variables.length > 0 && ( + +
+ + {t(`memoryConversation.variableConfig`)} +
+
+ )} {/* @@ -399,6 +469,11 @@ const Conversation: FC = () => { ref={uploadFileListModalRef} refresh={addFileList} /> + ) } diff --git a/web/src/views/Conversation/types.ts b/web/src/views/Conversation/types.ts index deb14d1f..cc074c1b 100644 --- a/web/src/views/Conversation/types.ts +++ b/web/src/views/Conversation/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:57:46 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 21:11:19 + * @Last Modified time: 2026-03-03 13:46:55 */ /** * Type definitions for Conversation @@ -51,6 +51,7 @@ export interface QueryParams { /** Current conversation ID */ conversation_id?: string | null; files?: any[]; + variables?: Record; } export interface UploadFileListModalRef { diff --git a/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx b/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx index dd84e6ba..5acd3eb1 100644 --- a/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx +++ b/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx @@ -71,20 +71,20 @@ const VariableConfigModal = forwardRef { - field.type === 'string' && + (field.type === 'string' || field.type === 'text') && } { field.type === 'number' && form.setFieldValue(['variables', name, 'value'], value)} /> } { - field.type === 'boolean' && {`${field.name}·${field.description}`} + field.type === 'boolean' && {`${field.name}·${field.display_name || field.description}`} } ) diff --git a/web/src/views/Workflow/components/Properties/VariableList/types.ts b/web/src/views/Workflow/components/Properties/VariableList/types.ts index 1cc2939a..64e62c87 100644 --- a/web/src/views/Workflow/components/Properties/VariableList/types.ts +++ b/web/src/views/Workflow/components/Properties/VariableList/types.ts @@ -1,5 +1,6 @@ export interface Variable { name: string; + display_name?: string; type: string; required: boolean; description: string; From 0826a34d8bb4bdd0faaf1620f80774e5e8160773 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 13:57:31 +0800 Subject: [PATCH 055/164] fix(web): http node body variable filter update --- .../components/Properties/HttpRequest/EditableTable.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx index 74593913..19706a71 100644 --- a/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx +++ b/web/src/views/Workflow/components/Properties/HttpRequest/EditableTable.tsx @@ -59,7 +59,7 @@ const EditableTable: FC = ({ render: (_: any, __: TableRow, index: number) => ( !option.dataType.includes('file'))} type="input" className={contentClassName} size={size} @@ -109,7 +109,7 @@ const EditableTable: FC = ({ const currentType = form.getFieldValue([...Array.isArray(parentName) ? parentName : [parentName], index, 'type']); const filteredOptions = currentType === 'file' ? booleanFilterOptions.filter(option => option.dataType.includes('file')) - : booleanFilterOptions; + : booleanFilterOptions.filter(option => !option.dataType.includes('file')); return ( From 7f36a06f26a8ed22ae0db43db96b574edc6f194f Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 14:05:02 +0800 Subject: [PATCH 056/164] fix(web): update share version modal's title --- .../ApplicationConfig/components/ReleaseShareModal.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx b/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx index b98c1aa4..f26441fd 100644 --- a/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx +++ b/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:28:46 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:28:46 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 14:03:44 */ /** * Release Share Modal @@ -79,7 +79,7 @@ const ReleaseShareModal = forwardRef{t('application.shareVersion')} {version?.version}} + title={<>{t('application.shareVersion')} ({version?.version_name && version.version_name[0].toLocaleLowerCase() === 'v' ? version.version_name : version?.version_name ? `v${version.version_name}` : `v${version?.version}`})} open={visible} onCancel={handleClose} footer={false} From ee4027c561db166dcb4a4c2f816014e3ba094f64 Mon Sep 17 00:00:00 2001 From: yujiangping Date: Tue, 3 Mar 2026 14:47:24 +0800 Subject: [PATCH 057/164] feat(web): enhance knowledge base sharing with stop share feedback - Fix file download URL to use absolute API path instead of apiPrefix variable - Add stopShareSuccess i18n message for English locale - Add stopShareSuccess i18n message for Chinese locale - Update ShareModal to display different success messages based on share toggle state - Show "Sharing is off" message when disabling knowledge base sharing - Improve user feedback when toggling share status on/off --- web/src/api/knowledgeBase.ts | 2 +- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + web/src/views/KnowledgeBase/components/ShareModal.tsx | 9 +++++++-- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 1353067e..60ed2403 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -154,7 +154,7 @@ export const uploadFile = async (data: FormData, options?: UploadFileOptions) => // 下载文件 export const downloadFile = async (fileId: string, fileName?: string) => { const token = cookieUtils.get('authToken'); - const url = `${apiPrefix}/files/${fileId}`; + const url = `/api/files/${fileId}`; try { const response = await fetch(url, { diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 9892b728..9bf8ee00 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -580,6 +580,7 @@ export const en = { knowledgeBase: { pleaseUploadFileFirst: 'Please upload file first', shareSuccess: 'Share successfully', + stopShareSuccess: 'Sharing is off. Access denied. ', shareFailed: 'Share failed', allModels: 'All Models', knowledgeBaseInfo: 'Knowledge base information', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index c72da969..f8316651 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -122,6 +122,7 @@ export const zh = { preview: '预览', pleaseUploadFileFirst: '请先上传文件', shareSuccess: '分享成功', + stopShareSuccess: '已取消分享,对方将无法访问该知识库', shareFailed: '分享失败', allModels: '所有模型', knowledgeBaseInfo: '知识库信息', diff --git a/web/src/views/KnowledgeBase/components/ShareModal.tsx b/web/src/views/KnowledgeBase/components/ShareModal.tsx index 1877e337..dc4a732a 100644 --- a/web/src/views/KnowledgeBase/components/ShareModal.tsx +++ b/web/src/views/KnowledgeBase/components/ShareModal.tsx @@ -4,7 +4,7 @@ * @Author: yujiangping * @Date: 2025-11-10 18:52:55 * @LastEditors: yujiangping - * @LastEditTime: 2026-02-10 15:18:32 + * @LastEditTime: 2026-03-03 14:46:08 */ import { forwardRef, useImperativeHandle, useState, useRef } from 'react'; import { Switch } from 'antd'; @@ -75,7 +75,12 @@ const ShareModal = forwardRef(({ handleShare: updateKnowledgeBase(item.target_kb?.id, { status: checked ? 1 : 2 }).then(() => { - messageApi.success(t('knowledgeBase.shareSuccess')); + if(checked){ + messageApi.success(t('knowledgeBase.shareSuccess')); + }else{ + messageApi.success(t('knowledgeBase.stopShareSuccess')); + } + getShareSpaceList(kbId); }).catch(() => { messageApi.error(t('knowledgeBase.shareFailed')); From 9a98ccff2c701635bed3ae996c73669549923cf0 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 3 Mar 2026 14:48:50 +0800 Subject: [PATCH 058/164] feat(web): agent compare chat add variables --- web/src/views/ApplicationConfig/Agent.tsx | 6 +++- .../ApplicationConfig/components/Chat.tsx | 32 ++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 6018c600..237c3373 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 11:14:30 + * @Last Modified time: 2026-03-03 14:24:34 */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' @@ -403,6 +403,9 @@ const Agent = forwardRef((_props, ref) => { const handleSaveChatVariable = (values: Variable[]) => { setChatVariables(values) } + useEffect(() => { + setChatVariables(values?.variables || []) + }, [values?.variables]) console.log('values', values) return ( <> @@ -507,6 +510,7 @@ const Agent = forwardRef((_props, ref) => { chatList={chatList} updateChatList={setChatList} handleSave={handleSave} + chatVariables={chatVariables} /> diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 794489c6..8cb6812c 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:27:39 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 17:40:15 + * @Last Modified time: 2026-03-03 14:21:54 */ /** * Chat debugging component for application testing @@ -13,7 +13,7 @@ import { type FC, useEffect, useState, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx' -import { Flex, Dropdown, type MenuProps } from 'antd' +import { Flex, Dropdown, type MenuProps, App } from 'antd' import ChatIcon from '@/assets/images/application/chat.png' import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png' @@ -28,6 +28,7 @@ import UploadFiles from '@/views/Conversation/components/FileUpload' // import AudioRecorder from '@/components/AudioRecorder' import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal' import type { UploadFileListModalRef } from '@/views/Conversation/types' +import type { Variable } from './VariableList/types' /** * Component props @@ -43,14 +44,16 @@ interface ChatProps { handleSave: (flag?: boolean) => Promise; /** Source type: multi-agent cluster or single agent */ source?: 'multi_agent' | 'agent'; + chatVariables?: Variable[]; // Add chatVariables prop } /** * Chat debugging component * Allows testing application with different model configurations side-by-side */ -const Chat: FC = ({ chatList, data, updateChatList, handleSave, source = 'agent' }) => { +const Chat: FC = ({ chatList, data, updateChatList, handleSave, source = 'agent', chatVariables }) => { const { t } = useTranslation(); + const { message: messageApi } = App.useApp() const [loading, setLoading] = useState(false) const [isCluster, setIsCluster] = useState(source === 'multi_agent') const [conversationId, setConversationId] = useState(null) @@ -195,6 +198,27 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }; setTimeout(() => { + // Validate required variables before sending + let isCanSend = true + const params: Record = {} + if (chatVariables && chatVariables.length > 0) { + const needRequired: string[] = [] + chatVariables.forEach(vo => { + params[vo.name] = vo.value + + if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { + isCanSend = false + needRequired.push(vo.name) + } + }) + + if (needRequired.length) { + messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) + } + } + if (!isCanSend) { + return + } runCompare(data.app_id, { message, files: fileList.map(file => { @@ -214,7 +238,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc model_parameters: item.model_parameters, conversation_id: item.conversation_id })), - variables: {}, + variables: params, "parallel": true, "stream": true, "timeout": 60, From 66c153f1ad9e21b69ff4c980b5892abeb0a09948 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 3 Mar 2026 16:48:34 +0800 Subject: [PATCH 059/164] refactor(api): improve memory service dependency injection and code organization - Update ShortService and LongService constructors to accept db Session parameter for proper dependency injection instead of using module-level db instance - Reorganize imports in memory_short_term_controller.py following PEP 8 conventions (stdlib, third-party, local imports) - Add comprehensive docstrings with type hints to ShortService and LongService methods for better code documentation - Fix import organization in memory_short_service.py to group related imports and improve readability - Reorganize imports in user_memory_service.py to follow consistent import ordering patterns - Update ShortService instantiation in analytics_memory_types to pass db parameter - Remove module-level db instance initialization in favor of caller-managed database session lifecycle - Add type annotations to method signatures (end_user_id: str, db: Session, return types) - Improve code formatting and spacing consistency across memory service files --- .../memory_short_term_controller.py | 18 +++--- api/app/services/memory_short_service.py | 61 ++++++++++++++----- api/app/services/user_memory_service.py | 25 +++++--- 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 1cca266e..0acac6ce 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -1,16 +1,18 @@ -from fastapi import APIRouter, Depends, HTTPException, status,Header +from typing import Optional + +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, Header, HTTPException, status +from sqlalchemy.orm import Session + from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User - +from app.services.memory_short_service import LongService, ShortService from app.services.memory_storage_service import search_entity -from app.services.memory_short_service import ShortService,LongService -from dotenv import load_dotenv -from sqlalchemy.orm import Session -from typing import Optional + load_dotenv() api_logger = get_api_logger() @@ -29,11 +31,11 @@ async def short_term_configs( language = get_language_from_header(language_type) # 获取短期记忆数据 - short_term=ShortService(end_user_id) + short_term=ShortService(end_user_id, db) short_result=short_term.get_short_databasets() short_count=short_term.get_short_count() - long_term=LongService(end_user_id) + long_term=LongService(end_user_id, db) long_result=long_term.get_long_databasets() entity_result = await search_entity(end_user_id) diff --git a/api/app/services/memory_short_service.py b/api/app/services/memory_short_service.py index fa3870f0..fa9623e0 100644 --- a/api/app/services/memory_short_service.py +++ b/api/app/services/memory_short_service.py @@ -1,22 +1,37 @@ +from typing import Dict, List + +from sqlalchemy.orm import Session from app.core.logging_config import get_api_logger -from app.db import get_db -from app.repositories.memory_short_repository import LongTermMemoryRepository -from app.repositories.memory_short_repository import ShortTermMemoryRepository - +from app.repositories.memory_short_repository import ( + LongTermMemoryRepository, + ShortTermMemoryRepository, +) api_logger = get_api_logger() -db=next(get_db()) + + class ShortService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for short-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.short_repo = ShortTermMemoryRepository(db) self.end_user_id = end_user_id - def get_short_databasets(self): + def get_short_databasets(self) -> List[Dict]: + """Retrieve the latest short-term memory entries for the user. + + Returns: + List[Dict]: List of memory dicts with retrieval, message, and answer keys. + """ short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3) short_result = [] for memory in short_memories: - deep_expanded = {} # Create a new dictionary for each memory + deep_expanded = {} messages = memory.messages aimessages = memory.aimessages retrieved_content = memory.retrieved_content or [] @@ -27,23 +42,41 @@ class ShortService: for item in retrieved_content: if isinstance(item, dict): for key, values in item.items(): - retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"}) + retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"}) deep_expanded['retrieval'] = retrieval_source - deep_expanded['message'] = messages # 修正拼写错误 + deep_expanded['message'] = messages deep_expanded['answer'] = aimessages short_result.append(deep_expanded) return short_result - def get_short_count(self): + + def get_short_count(self) -> int: + """Count total short-term memory entries for the user. + + Returns: + int: Number of short-term memory records. + """ short_count = self.short_repo.count_by_user_id(self.end_user_id) return short_count + class LongService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for long-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.long_repo = LongTermMemoryRepository(db) self.end_user_id = end_user_id - def get_long_databasets(self): - # 获取长期记忆数据 + + def get_long_databasets(self) -> List[Dict]: + """Retrieve long-term memory retrieval data for the user. + + Returns: + List[Dict]: List of dicts with query and retrieval keys. + """ long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1) long_result = [] diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index e34756b9..db5051d2 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -10,6 +10,9 @@ from collections import Counter from datetime import datetime from typing import Any, Dict, List, Optional, Tuple +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + from app.core.logging_config import get_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context @@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session logger = get_logger(__name__) @@ -1035,9 +1036,10 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan "growth_trajectory": str # 成长轨迹 } """ - from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt - from app.core.language_utils import validate_language import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt # 验证语言参数 language = validate_language(language) @@ -1161,11 +1163,12 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st "one_sentence": str } """ - from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt - from app.core.language_utils import validate_language - from app.repositories.end_user_repository import EndUserRepository - from app.db import get_db import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt + from app.db import get_db + from app.repositories.end_user_repository import EndUserRepository # 验证语言参数 language = validate_language(language) @@ -1457,7 +1460,7 @@ async def analytics_memory_types( short_term_count = 0 if end_user_id: try: - short_term_service = ShortService(end_user_id) + short_term_service = ShortService(end_user_id, db) short_term_data = short_term_service.get_short_databasets() # 统计 short_term 数组的长度 if short_term_data: @@ -1471,8 +1474,10 @@ async def analytics_memory_types( forgetting_threshold = 0.3 # 默认值 if end_user_id: try: + from app.core.memory.storage_services.forgetting_engine.config_utils import ( + load_actr_config_from_db, + ) from app.services.memory_agent_service import get_end_user_connected_config - from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db # 获取用户关联的 config_id connected_config = get_end_user_connected_config(end_user_id, db) From d899b274489043d5474e7d0072e17af5d27b987b Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 3 Mar 2026 22:46:05 +0800 Subject: [PATCH 060/164] [changes] The timing of the memory increment task has been changed from relative time to absolute time. --- api/app/celery_app.py | 4 ++-- api/app/core/config.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index f422f4a0..151c1e67 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,6 +1,7 @@ import os import platform from datetime import timedelta +from celery.schedules import crontab from urllib.parse import quote from celery import Celery @@ -90,9 +91,8 @@ celery_app.conf.update( celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +memory_increment_schedule = crontab(hour=2, minute=0) # 每天凌晨 2:00 执行 memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) -# 这个30秒的设计不合理 workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 diff --git a/api/app/core/config.py b/api/app/core/config.py index 6a2cf206..4472d373 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -200,7 +200,6 @@ class Settings: REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) - MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) # Memory Cache Regeneration Configuration From 8466c8e0192c84641ec2bd8f088075604b814bb2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 3 Mar 2026 23:30:54 +0800 Subject: [PATCH 061/164] [fix] Revising the judgment method for the interest analysis tags --- .../controllers/memory_agent_controller.py | 32 ++--- .../core/memory/analytics/hot_memory_tags.py | 112 ++++++++++++++++++ .../core/memory/utils/prompt/prompt_utils.py | 17 +++ .../prompt/prompts/interest_filter.jinja2 | 47 ++++++++ api/app/services/memory_agent_service.py | 32 ++--- 5 files changed, 210 insertions(+), 30 deletions(-) create mode 100644 api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b88e65ff..8f2e5c31 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -661,34 +661,38 @@ async def get_knowledge_type_stats_api( return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e)) -@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) -async def get_hot_memory_tags_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - limit: int = Query(20, description="返回标签数量限制"), +@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) +async def get_interest_distribution_by_user_api( + end_user_id: Optional[str] = Query(None, description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), current_user: User = Depends(get_current_user), - db: Session=Depends(get_db), + db: Session = Depends(get_db), ): """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签 - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] """ - api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") + language = get_language_from_header(language_type) + api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: - result = await memory_agent_service.get_hot_memory_tags_by_user( + result = await memory_agent_service.get_interest_distribution_by_user( end_user_id=end_user_id, - limit=limit + limit=limit, + language=language ) - return success(data=result, msg="获取热门记忆标签成功") + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: - api_logger.error(f"Hot memory tags by user failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e)) + api_logger.error(f"Interest distribution by user failed: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e)) @router.get("/analytics/user_profile", response_model=ApiResponse) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index abb0f138..da08e88e 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -16,6 +16,10 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") +class InterestTags(BaseModel): + """用于接收LLM筛选后的兴趣活动标签列表的模型。""" + interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。") + async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 @@ -89,6 +93,70 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags +async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]: + """ + 使用LLM从标签列表中筛选出代表用户兴趣活动的标签。 + + 与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣, + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 + + Args: + tags: 原始标签列表 + end_user_id: 用户ID,用于获取LLM配置 + + Returns: + 筛选后的兴趣活动标签列表 + """ + try: + with get_db_context() as db: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + + if not config_id and not workspace_id: + raise ValueError( + f"No memory_config_id found for end_user_id: {end_user_id}." + ) + + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + + tag_list_str = ", ".join(tags) + from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt + rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language) + messages = [ + { + "role": "user", + "content": rendered_prompt + } + ] + + structured_response = await llm_client.response_structured( + messages=messages, + response_model=InterestTags + ) + + return structured_response.interest_tags + + except Exception as e: + print(f"兴趣标签LLM筛选过程中发生错误: {e}") + return tags + + async def get_raw_tags_from_db( connector: Neo4jConnector, end_user_id: str, @@ -183,3 +251,47 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = finally: # 确保关闭连接 await connector.close() + +async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]: + """ + 获取用户的兴趣分布标签。 + + 与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt, + 过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。 + + Args: + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id + limit: 最终返回的标签数量限制(默认10) + by_user: 是否按user_id查询(默认False,按end_user_id查询) + + Raises: + ValueError: 如果end_user_id未提供或为空 + """ + if not end_user_id or not end_user_id.strip(): + raise ValueError( + "end_user_id is required. Please provide a valid end_user_id or user_id." + ) + + connector = Neo4jConnector() + try: + # 查询更多原始标签,给LLM提供充足上下文 + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) + if not raw_tags_with_freq: + return [] + + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + + # 使用兴趣活动专用prompt进行筛选 + interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) + + # 保留原始频率,按兴趣筛选结果过滤 + final_tags = [ + (tag, freq) + for tag, freq in raw_tags_with_freq + if tag in interest_tag_names + ] + + return final_tags[:limit] + finally: + await connector.close() diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index d88f50cf..0cea98f2 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt( }) return rendered_prompt + + +def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str: + """ + Renders the interest filter prompt using the interest_filter.jinja2 template. + + Args: + tag_list: Comma-separated string of raw tags to filter + language: Output language ("zh" for Chinese, "en" for English) + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("interest_filter.jinja2") + rendered_prompt = template.render(tag_list=tag_list, language=language) + log_prompt_rendering('interest filter', rendered_prompt) + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 new file mode 100644 index 00000000..1e3aac55 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -0,0 +1,47 @@ +{% if language == "zh" %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in Chinese. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., '攀岩', '篮球', '游泳', '跑步') +- Tags representing cultural or entertainment hobbies (e.g., '读书', '看电影', '听音乐', '摄影') +- Tags representing learning or creative activities (e.g., '编程', '绘画', '写作', '烹饪') +- Tags representing specific interest domains or hobby categories (e.g., '历史', '天文', '园艺') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., '篮球鞋', '相机', '书桌') +- Pure location or venue names (e.g., '篮球场', '图书馆', '健身房') +- Abstract concepts or quality descriptions (e.g., '核心力量', '团队合作', '专注力') +- Person names, brand names, or proper nouns (e.g., '乔丹', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep '篮球' over '打篮球'; keep '读书' over '阅读'. + +**Example**: +Input: ['攀岩', '篮球场', '篮球鞋', '篮球', '《三体》', '历史', '核心力量', '烹饪', '菜刀'] +Output: ['攀岩', '篮球', '历史', '烹饪'] + +Please filter the following tag list and return only the tags that represent user interest activities in Chinese: {{ tag_list }} +{% else %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in English. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., 'rock climbing', 'basketball', 'swimming', 'running') +- Tags representing cultural or entertainment hobbies (e.g., 'reading', 'watching movies', 'listening to music', 'photography') +- Tags representing learning or creative activities (e.g., 'programming', 'painting', 'writing', 'cooking') +- Tags representing specific interest domains or hobby categories (e.g., 'history', 'astronomy', 'gardening') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., 'basketball shoes', 'camera', 'desk') +- Pure location or venue names (e.g., 'basketball court', 'library', 'gym') +- Abstract concepts or quality descriptions (e.g., 'core strength', 'teamwork', 'focus') +- Person names, brand names, or proper nouns (e.g., 'Jordan', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep 'basketball' over 'playing basketball'; keep 'reading' over 'reading books'. + +**Example**: +Input: ['rock climbing', 'basketball court', 'basketball shoes', 'basketball', 'The Three-Body Problem', 'history', 'core strength', 'cooking', 'kitchen knife'] +Output: ['rock climbing', 'basketball', 'history', 'cooking'] + +Please filter the following tag list and return only the tags that represent user interest activities in English: {{ tag_list }} +{% endif %} diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1f3667a6..16aee283 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import ( ) from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType @@ -890,36 +890,36 @@ class MemoryAgentService: return result - async def get_hot_memory_tags_by_user( + + async def get_interest_distribution_by_user( self, end_user_id: Optional[str] = None, - limit: int = 20 + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签。 + + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 + - end_user_id: 用户ID(必填) - limit: 返回标签数量限制 + - language: 输出语言("zh" 中文, "en" 英文) 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] - - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 """ try: - # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) - tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [] - for tag, freq in tags: - payload.append({"name": tag, "frequency": freq}) - return payload + tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language) + return [{"name": tag, "frequency": freq} for tag, freq in tags] except Exception as e: - logger.error(f"热门记忆标签查询失败: {e}") - raise Exception(f"热门记忆标签查询失败: {e}") + logger.error(f"兴趣分布标签查询失败: {e}") + raise Exception(f"兴趣分布标签查询失败: {e}") async def get_user_profile( From 68c4c7429c5f8523e34774bde35bbdd6ff5f168c Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 10:59:29 +0800 Subject: [PATCH 062/164] feat(web): change interest distribution api --- web/src/api/memory.ts | 8 ++++---- .../UserMemoryDetail/components/InterestDistribution.tsx | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 987ef358..ef7aa460 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 14:00:06 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-04 10:58:41 */ import { request } from '@/utils/request' import type { @@ -98,8 +98,8 @@ export const getMemorySearchEdges = (end_user_id: string) => { return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }) } // User Memory - User interest distribution -export const getHotMemoryTagsByUser = (end_user_id: string) => { - return request.get(`/memory/analytics/hot_memory_tags/by_user`, { end_user_id }) +export const getInterestDistributionByUser = (end_user_id: string) => { + return request.get(`/memory/analytics/interest_distribution/by_user`, { end_user_id }) } // User Memory - Total memory count export const getTotalMemoryCountByUser = (end_user_id: string) => { diff --git a/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx b/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx index d48013b3..849e1eb2 100644 --- a/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx +++ b/web/src/views/UserMemoryDetail/components/InterestDistribution.tsx @@ -15,7 +15,7 @@ import { useParams } from 'react-router-dom' import ReactEcharts from 'echarts-for-react'; import { Space } from 'antd' -import { getHotMemoryTagsByUser } from '@/api/memory'; +import { getInterestDistributionByUser } from '@/api/memory'; import Empty from '@/components/Empty'; import Loading from '@/components/Empty/Loading'; import RbCard from '@/components/RbCard/Card'; @@ -38,7 +38,7 @@ const InterestDistribution: FC = () => { /** Fetch interest distribution data */ const getData = () => { setLoading(true) - getHotMemoryTagsByUser(id as string).then(res => { + getInterestDistributionByUser(id as string).then(res => { const response = res as { name: string; frequency: number }[] setData(response.map(item => ({ ...item, From 9115ad6950dc30c5fa20aaa0c0d9f3aa23cca7a9 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 3 Mar 2026 23:30:54 +0800 Subject: [PATCH 063/164] [fix] Revising the judgment method for the interest analysis tags --- .../controllers/memory_agent_controller.py | 32 ++--- .../core/memory/analytics/hot_memory_tags.py | 112 ++++++++++++++++++ .../core/memory/utils/prompt/prompt_utils.py | 17 +++ .../prompt/prompts/interest_filter.jinja2 | 47 ++++++++ api/app/services/memory_agent_service.py | 32 ++--- 5 files changed, 210 insertions(+), 30 deletions(-) create mode 100644 api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b88e65ff..8f2e5c31 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -661,34 +661,38 @@ async def get_knowledge_type_stats_api( return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e)) -@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) -async def get_hot_memory_tags_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - limit: int = Query(20, description="返回标签数量限制"), +@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) +async def get_interest_distribution_by_user_api( + end_user_id: Optional[str] = Query(None, description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), current_user: User = Depends(get_current_user), - db: Session=Depends(get_db), + db: Session = Depends(get_db), ): """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签 - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] """ - api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") + language = get_language_from_header(language_type) + api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: - result = await memory_agent_service.get_hot_memory_tags_by_user( + result = await memory_agent_service.get_interest_distribution_by_user( end_user_id=end_user_id, - limit=limit + limit=limit, + language=language ) - return success(data=result, msg="获取热门记忆标签成功") + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: - api_logger.error(f"Hot memory tags by user failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e)) + api_logger.error(f"Interest distribution by user failed: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e)) @router.get("/analytics/user_profile", response_model=ApiResponse) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index abb0f138..da08e88e 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -16,6 +16,10 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") +class InterestTags(BaseModel): + """用于接收LLM筛选后的兴趣活动标签列表的模型。""" + interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。") + async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 @@ -89,6 +93,70 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags +async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]: + """ + 使用LLM从标签列表中筛选出代表用户兴趣活动的标签。 + + 与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣, + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 + + Args: + tags: 原始标签列表 + end_user_id: 用户ID,用于获取LLM配置 + + Returns: + 筛选后的兴趣活动标签列表 + """ + try: + with get_db_context() as db: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + + if not config_id and not workspace_id: + raise ValueError( + f"No memory_config_id found for end_user_id: {end_user_id}." + ) + + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + + tag_list_str = ", ".join(tags) + from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt + rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language) + messages = [ + { + "role": "user", + "content": rendered_prompt + } + ] + + structured_response = await llm_client.response_structured( + messages=messages, + response_model=InterestTags + ) + + return structured_response.interest_tags + + except Exception as e: + print(f"兴趣标签LLM筛选过程中发生错误: {e}") + return tags + + async def get_raw_tags_from_db( connector: Neo4jConnector, end_user_id: str, @@ -183,3 +251,47 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = finally: # 确保关闭连接 await connector.close() + +async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]: + """ + 获取用户的兴趣分布标签。 + + 与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt, + 过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。 + + Args: + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id + limit: 最终返回的标签数量限制(默认10) + by_user: 是否按user_id查询(默认False,按end_user_id查询) + + Raises: + ValueError: 如果end_user_id未提供或为空 + """ + if not end_user_id or not end_user_id.strip(): + raise ValueError( + "end_user_id is required. Please provide a valid end_user_id or user_id." + ) + + connector = Neo4jConnector() + try: + # 查询更多原始标签,给LLM提供充足上下文 + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) + if not raw_tags_with_freq: + return [] + + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + + # 使用兴趣活动专用prompt进行筛选 + interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) + + # 保留原始频率,按兴趣筛选结果过滤 + final_tags = [ + (tag, freq) + for tag, freq in raw_tags_with_freq + if tag in interest_tag_names + ] + + return final_tags[:limit] + finally: + await connector.close() diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index d88f50cf..0cea98f2 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt( }) return rendered_prompt + + +def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str: + """ + Renders the interest filter prompt using the interest_filter.jinja2 template. + + Args: + tag_list: Comma-separated string of raw tags to filter + language: Output language ("zh" for Chinese, "en" for English) + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("interest_filter.jinja2") + rendered_prompt = template.render(tag_list=tag_list, language=language) + log_prompt_rendering('interest filter', rendered_prompt) + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 new file mode 100644 index 00000000..1e3aac55 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -0,0 +1,47 @@ +{% if language == "zh" %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in Chinese. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., '攀岩', '篮球', '游泳', '跑步') +- Tags representing cultural or entertainment hobbies (e.g., '读书', '看电影', '听音乐', '摄影') +- Tags representing learning or creative activities (e.g., '编程', '绘画', '写作', '烹饪') +- Tags representing specific interest domains or hobby categories (e.g., '历史', '天文', '园艺') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., '篮球鞋', '相机', '书桌') +- Pure location or venue names (e.g., '篮球场', '图书馆', '健身房') +- Abstract concepts or quality descriptions (e.g., '核心力量', '团队合作', '专注力') +- Person names, brand names, or proper nouns (e.g., '乔丹', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep '篮球' over '打篮球'; keep '读书' over '阅读'. + +**Example**: +Input: ['攀岩', '篮球场', '篮球鞋', '篮球', '《三体》', '历史', '核心力量', '烹饪', '菜刀'] +Output: ['攀岩', '篮球', '历史', '烹饪'] + +Please filter the following tag list and return only the tags that represent user interest activities in Chinese: {{ tag_list }} +{% else %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in English. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., 'rock climbing', 'basketball', 'swimming', 'running') +- Tags representing cultural or entertainment hobbies (e.g., 'reading', 'watching movies', 'listening to music', 'photography') +- Tags representing learning or creative activities (e.g., 'programming', 'painting', 'writing', 'cooking') +- Tags representing specific interest domains or hobby categories (e.g., 'history', 'astronomy', 'gardening') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., 'basketball shoes', 'camera', 'desk') +- Pure location or venue names (e.g., 'basketball court', 'library', 'gym') +- Abstract concepts or quality descriptions (e.g., 'core strength', 'teamwork', 'focus') +- Person names, brand names, or proper nouns (e.g., 'Jordan', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep 'basketball' over 'playing basketball'; keep 'reading' over 'reading books'. + +**Example**: +Input: ['rock climbing', 'basketball court', 'basketball shoes', 'basketball', 'The Three-Body Problem', 'history', 'core strength', 'cooking', 'kitchen knife'] +Output: ['rock climbing', 'basketball', 'history', 'cooking'] + +Please filter the following tag list and return only the tags that represent user interest activities in English: {{ tag_list }} +{% endif %} diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1f3667a6..16aee283 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import ( ) from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType @@ -890,36 +890,36 @@ class MemoryAgentService: return result - async def get_hot_memory_tags_by_user( + + async def get_interest_distribution_by_user( self, end_user_id: Optional[str] = None, - limit: int = 20 + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签。 + + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 + - end_user_id: 用户ID(必填) - limit: 返回标签数量限制 + - language: 输出语言("zh" 中文, "en" 英文) 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] - - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 """ try: - # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) - tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [] - for tag, freq in tags: - payload.append({"name": tag, "frequency": freq}) - return payload + tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language) + return [{"name": tag, "frequency": freq} for tag, freq in tags] except Exception as e: - logger.error(f"热门记忆标签查询失败: {e}") - raise Exception(f"热门记忆标签查询失败: {e}") + logger.error(f"兴趣分布标签查询失败: {e}") + raise Exception(f"兴趣分布标签查询失败: {e}") async def get_user_profile( From 31bee889d7a4eca95c42a4a2decdf9d462baf8fe Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 11:52:54 +0800 Subject: [PATCH 064/164] feat(web): model add is_vision/is_omni config --- web/src/i18n/en.ts | 8 +++- web/src/i18n/zh.ts | 6 +++ .../components/CustomModelModal.tsx | 44 +++++++++++++++---- .../ModelImplement/SubModelModal.tsx | 15 ++++--- .../components/ModelListDetail.tsx | 7 +-- .../components/ModelSquareDetail.tsx | 9 ++-- web/src/views/ModelManagement/types.ts | 13 ++++-- 7 files changed, 77 insertions(+), 25 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index d404dd6e..6bd3ea3f 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -603,7 +603,13 @@ export const en = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + + is_vision: 'Vision Support', + is_omni: 'Omni Support', + vision: 'Vision', + audio: 'Audio', + video: 'Video', }, knowledgeBase: { home: 'Home', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index fc6bb822..43c3b8be 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1184,6 +1184,12 @@ export const zh = { xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", + + is_vision: '支持视觉', + is_omni: '支持全模态', + vision: '视觉', + audio: '音频', + video: '视频', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx index 112534a5..d47fc996 100644 --- a/web/src/views/ModelManagement/components/CustomModelModal.tsx +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:28 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-28 17:24:05 + * @Last Modified time: 2026-03-04 11:31:43 */ /** * Custom Model Modal @@ -10,8 +10,8 @@ * Supports logo upload, type/provider selection, and tagging */ -import { forwardRef, useImperativeHandle, useState } from 'react'; -import { Form, Input, App } from 'antd'; +import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; +import { Form, Input, App, Checkbox } from 'antd'; import { useTranslation } from 'react-i18next'; import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types'; @@ -35,6 +35,14 @@ const CustomModelModal = forwardRef( const [isEdit, setIsEdit] = useState(false); const [form] = Form.useForm(); const [loading, setLoading] = useState(false) + const modelType = Form.useWatch(['type'], form); + const isOmni = Form.useWatch(['is_omni'], form); + + useEffect(() => { + if (isOmni) { + form.setFieldsValue({ is_vision: true }) + } + }, [isOmni]) /** Close modal and reset state */ const handleClose = () => { @@ -49,9 +57,12 @@ const CustomModelModal = forwardRef( if (model) { setIsEdit(true); setModel(model); + const { capability, is_omni, ...rest} = model form.setFieldsValue({ - ...model, - logo: model.logo && model.logo.startsWith('http') ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined + ...rest, + logo: model.logo && model.logo.startsWith('http') ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined, + is_omni, + is_vision: capability?.includes('vision') || false, }); } else { setIsEdit(false); @@ -79,9 +90,14 @@ const CustomModelModal = forwardRef( form .validateFields() .then((values) => { - const { logo, ...rest } = values; + const { logo, type, is_vision, is_omni, ...rest } = values; const formData: CustomModelForm = { - ...rest + ...rest, + type, + } + if (!['embedding', 'rerank'].includes(type as string)) { + formData.capability = is_omni ? ["vision", "audio"] : is_vision ? ['vision'] : [] + formData.is_omni = is_omni } if (typeof logo === 'object' && logo?.response?.data.file_id) { @@ -108,7 +124,7 @@ const CustomModelModal = forwardRef( useImperativeHandle(ref, () => ({ handleOpen, })); - + console.log('modelType', modelType) return ( ( - ( > + + {!['embedding', 'rerank'].includes(modelType as string) && + <> + + {t('modelNew.is_omni')} + + + {t('modelNew.is_vision')} + + + } ); diff --git a/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx b/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx index e312b779..b2b44bf3 100644 --- a/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx +++ b/web/src/views/ModelManagement/components/ModelImplement/SubModelModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:20 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:54:54 + * @Last Modified time: 2026-03-04 11:51:01 */ /** * Sub-Model Modal @@ -10,8 +10,8 @@ * Uses cascader for hierarchical selection */ -import { forwardRef, useImperativeHandle, useState, useEffect } from 'react'; -import { Form, Cascader, App, type CascaderProps } from 'antd'; +import { type ReactNode, forwardRef, useImperativeHandle, useState, useEffect } from 'react'; +import { Form, Cascader, App, type CascaderProps, Space } from 'antd'; import { useTranslation } from 'react-i18next'; import type { SubModelModalForm, SubModelModalRef, SubModelModalProps } from './types'; @@ -19,6 +19,7 @@ import RbModal from '@/components/RbModal' import CustomSelect from '@/components/CustomSelect' import { modelProviderUrl, getModelNewList } from '@/api/models' import type { ProviderModelItem } from '../../types' +import Tag from '@/components/Tag'; const { SHOW_CHILD } = Cascader; @@ -27,7 +28,7 @@ const { SHOW_CHILD } = Cascader; */ interface Option { value: string | number; - label: string; + label: string | ReactNode; children?: Option[]; [key: string]: any; } @@ -116,7 +117,11 @@ const SubModelModal = forwardRef(({ })) return { ...vo, - label: vo.name, + label: + {vo.name} + {t(`modelNew.${vo.type}`)} + {vo.capability?.filter(item => item !== 'video').map(vo => {t(`modelNew.${vo}`)})} + , value: vo.id, children: children } diff --git a/web/src/views/ModelManagement/components/ModelListDetail.tsx b/web/src/views/ModelManagement/components/ModelListDetail.tsx index aad7b887..5291d5c4 100644 --- a/web/src/views/ModelManagement/components/ModelListDetail.tsx +++ b/web/src/views/ModelManagement/components/ModelListDetail.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:49:45 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:49:45 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-04 11:50:47 */ /** * Model List Detail Drawer @@ -133,9 +133,10 @@ const ModelListDetail = forwardRef(({ + subTitle={ {t(`modelNew.${item.type}`)} {item.api_keys.length}{t('modelNew.apiKeyNum')} + {item.capability?.filter(item => item !=='video').map(vo => {t(`modelNew.${vo}`)})} } avatarUrl={getLogoUrl(item.logo)} avatar={ diff --git a/web/src/views/ModelManagement/components/ModelSquareDetail.tsx b/web/src/views/ModelManagement/components/ModelSquareDetail.tsx index 4fee5a7b..6826e9f5 100644 --- a/web/src/views/ModelManagement/components/ModelSquareDetail.tsx +++ b/web/src/views/ModelManagement/components/ModelSquareDetail.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:49 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:54:26 + * @Last Modified time: 2026-03-04 11:50:31 */ /** * Model Square Detail Drawer @@ -89,9 +89,10 @@ const ModelSquareDetail = forwardRef - {t(`modelNew.${item.type}`)} - {item.is_official && {t(`modelNew.official`)}} + subTitle={ + {t(`modelNew.${item.type}`)} + {item.is_official && {t(`modelNew.official`)}} + {item.capability?.filter(item => item !== 'video').map(vo => {t(`modelNew.${vo}`)})} } avatarUrl={getLogoUrl(item.logo)} avatar={ diff --git a/web/src/views/ModelManagement/types.ts b/web/src/views/ModelManagement/types.ts index e7e1f9ac..3233353b 100644 --- a/web/src/views/ModelManagement/types.ts +++ b/web/src/views/ModelManagement/types.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:50:18 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:50:18 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-04 11:39:20 */ /** * Type definitions for Model Management @@ -148,7 +148,9 @@ export interface ModelListItem { /** Update timestamp */ updated_at: number; /** Associated API keys */ - api_keys: ModelApiKey[] + api_keys: ModelApiKey[]; + capability?: string[]; + is_omni?: boolean; } /** @@ -261,6 +263,8 @@ export interface ModelPlazaItem { add_count: number; /** Whether user has added this model */ is_added: boolean; + capability?: string[]; + is_omni?: boolean; } /** @@ -291,6 +295,9 @@ export interface CustomModelForm { /** API base URL */ api_base: string; }> + is_vision?: boolean; + is_omni?: boolean; + capability?: string[]; } /** From df34735a9bb2f15f7d9d19e92edbe857eaf75c5d Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 12:08:57 +0800 Subject: [PATCH 065/164] [add] Set cache for the distribution of interest tags --- api/app/cache/__init__.py | 3 +- api/app/cache/memory/__init__.py | 2 + api/app/cache/memory/interest_memory.py | 122 ++++++++++++++++++ .../controllers/memory_agent_controller.py | 19 +++ api/app/core/config.py | 2 +- .../core/memory/analytics/hot_memory_tags.py | 23 +++- .../prompt/prompts/interest_filter.jinja2 | 84 +++++++----- api/env.example | 2 +- 8 files changed, 215 insertions(+), 42 deletions(-) create mode 100644 api/app/cache/memory/interest_memory.py diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index a79d4cb2..46d1c959 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -3,9 +3,10 @@ Cache 缓存模块 提供各种缓存功能的统一入口 """ -from .memory import EmotionMemoryCache, ImplicitMemoryCache +from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache __all__ = [ "EmotionMemoryCache", "ImplicitMemoryCache", + "InterestMemoryCache", ] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 4ada3153..0e21df0f 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -5,8 +5,10 @@ Memory 缓存模块 """ from .emotion_memory import EmotionMemoryCache from .implicit_memory import ImplicitMemoryCache +from .interest_memory import InterestMemoryCache __all__ = [ "EmotionMemoryCache", "ImplicitMemoryCache", + "InterestMemoryCache", ] diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py new file mode 100644 index 00000000..108e2a37 --- /dev/null +++ b/api/app/cache/memory/interest_memory.py @@ -0,0 +1,122 @@ +""" +Interest Distribution Cache + +兴趣分布缓存模块 +用于缓存用户的兴趣分布标签数据,避免重复调用模型生成 +""" +import json +import logging +from typing import Optional, List, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + +# 缓存过期时间:24小时 +INTEREST_CACHE_EXPIRE = 86400 + + +class InterestMemoryCache: + """兴趣分布缓存类""" + + PREFIX = "cache:memory:interest_distribution" + + @classmethod + def _get_key(cls, end_user_id: str, language: str) -> str: + """生成 Redis key + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_user:{end_user_id}:{language}" + + @classmethod + async def set_interest_distribution( + cls, + end_user_id: str, + language: str, + data: List[Dict[str, Any]], + expire: int = INTEREST_CACHE_EXPIRE, + ) -> bool: + """设置用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...] + expire: 过期时间(秒),默认24小时 + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(end_user_id, language) + payload = { + "data": data, + "generated_at": datetime.now().isoformat(), + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> Optional[List[Dict[str, Any]]]: + """获取用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 兴趣分布列表,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(end_user_id, language) + value = await aio_redis.get(key) + if value: + payload = json.loads(value) + logger.info(f"命中兴趣分布缓存: {key}") + return payload.get("data") + logger.info(f"兴趣分布缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> bool: + """删除用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(end_user_id, language) + result = await aio_redis.delete(key) + logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 8f2e5c31..1f070eb6 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -1,5 +1,6 @@ from typing import List, Optional +from app.cache.memory.interest_memory import InterestMemoryCache from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header @@ -684,11 +685,29 @@ async def get_interest_distribution_by_user_api( language = get_language_from_header(language_type) api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: + # 优先读取缓存 + cached = await InterestMemoryCache.get_interest_distribution( + end_user_id=end_user_id, + language=language, + ) + if cached is not None: + api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}") + return success(data=cached, msg="获取兴趣分布标签成功") + + # 缓存未命中,调用模型生成 result = await memory_agent_service.get_interest_distribution_by_user( end_user_id=end_user_id, limit=limit, language=language ) + + # 写入缓存,24小时过期 + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + ) + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: api_logger.error(f"Interest distribution by user failed: {str(e)}") diff --git a/api/app/core/config.py b/api/app/core/config.py index 6a2cf206..d9132be2 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -230,7 +230,7 @@ class Settings: # General Ontology Type Configuration # ======================================================================== # 通用本体文件路径列表(逗号分隔) - GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl") + GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl") # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index da08e88e..1d2d5259 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -281,16 +281,25 @@ async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: return [] raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq} - # 使用兴趣活动专用prompt进行筛选 + # 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签) interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) - # 保留原始频率,按兴趣筛选结果过滤 - final_tags = [ - (tag, freq) - for tag, freq in raw_tags_with_freq - if tag in interest_tag_names - ] + # 构建最终标签列表: + # - 原始标签中存在的,保留原始频率 + # - LLM推断出的新标签(不在原始列表中),赋予默认频率1 + final_tags = [] + seen = set() + for tag in interest_tag_names: + if tag in seen: + continue + seen.add(tag) + freq = raw_freq_map.get(tag, 1) + final_tags.append((tag, freq)) + + # 按频率降序排列 + final_tags.sort(key=lambda x: x[1], reverse=True) return final_tags[:limit] finally: diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 index 1e3aac55..7957bf1c 100644 --- a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -1,47 +1,67 @@ {% if language == "zh" %} -You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in Chinese. +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. -**Keep Rules** (keep if any condition is met): -- Tags representing sports or physical activities the user actively participates in (e.g., '攀岩', '篮球', '游泳', '跑步') -- Tags representing cultural or entertainment hobbies (e.g., '读书', '看电影', '听音乐', '摄影') -- Tags representing learning or creative activities (e.g., '编程', '绘画', '写作', '烹饪') -- Tags representing specific interest domains or hobby categories (e.g., '历史', '天文', '园艺') +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" -**Filter Rules** (remove if any condition is met): -- Pure object or tool names that do not represent an activity (e.g., '篮球鞋', '相机', '书桌') -- Pure location or venue names (e.g., '篮球场', '图书馆', '健身房') -- Abstract concepts or quality descriptions (e.g., '核心力量', '团队合作', '专注力') -- Person names, brand names, or proper nouns (e.g., '乔丹', 'Nike') +Examples of inference: +- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩' +- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影' +- '晨间冥想坚持天数', '身心协同峰值' → '冥想' +- '川味可视化', '川菜' → '烹饪' +- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化' +- '吉他', '指弹', '琴谱' → '吉他' +- '跑步', '5公里', '跑鞋' → '跑步' +- '瑜伽垫', '瑜伽课' → '瑜伽' -**Merge Rules**: For semantically similar tags, keep only the most representative one. -For example: keep '篮球' over '打篮球'; keep '读书' over '阅读'. +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步') +- If multiple tags all point to '攀岩', output '攀岩' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., '助手', '用户', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分') +- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影') + +**Output format**: Return a list of concise interest activity names in Chinese. **Example**: -Input: ['攀岩', '篮球场', '篮球鞋', '篮球', '《三体》', '历史', '核心力量', '烹饪', '菜刀'] -Output: ['攀岩', '篮球', '历史', '烹饪'] +Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间'] +Output: ['攀岩', '摄影', '冥想', '烹饪', '编程'] -Please filter the following tag list and return only the tags that represent user interest activities in Chinese: {{ tag_list }} +Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }} {% else %} -You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in English. +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. -**Keep Rules** (keep if any condition is met): -- Tags representing sports or physical activities the user actively participates in (e.g., 'rock climbing', 'basketball', 'swimming', 'running') -- Tags representing cultural or entertainment hobbies (e.g., 'reading', 'watching movies', 'listening to music', 'photography') -- Tags representing learning or creative activities (e.g., 'programming', 'painting', 'writing', 'cooking') -- Tags representing specific interest domains or hobby categories (e.g., 'history', 'astronomy', 'gardening') +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" -**Filter Rules** (remove if any condition is met): -- Pure object or tool names that do not represent an activity (e.g., 'basketball shoes', 'camera', 'desk') -- Pure location or venue names (e.g., 'basketball court', 'library', 'gym') -- Abstract concepts or quality descriptions (e.g., 'core strength', 'teamwork', 'focus') -- Person names, brand names, or proper nouns (e.g., 'Jordan', 'Nike') +Examples of inference: +- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing' +- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography' +- 'morning meditation streak', 'mind-body peak' → 'meditation' +- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking' +- 'open source project', 'data visualization tool', 'Python' → 'programming' +- 'guitar', 'fingerpicking', 'sheet music' → 'guitar' +- 'running', '5km', 'running shoes' → 'running' -**Merge Rules**: For semantically similar tags, keep only the most representative one. -For example: keep 'basketball' over 'playing basketball'; keep 'reading' over 'reading books'. +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation') +- If multiple tags all point to 'rock climbing', output 'rock climbing' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating') + +**Output format**: Return a list of concise interest activity names in English. **Example**: -Input: ['rock climbing', 'basketball court', 'basketball shoes', 'basketball', 'The Three-Body Problem', 'history', 'core strength', 'cooking', 'kitchen knife'] -Output: ['rock climbing', 'basketball', 'history', 'cooking'] +Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time'] +Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming'] -Please filter the following tag list and return only the tags that represent user interest activities in English: {{ tag_list }} +Now process the following tag list and return the inferred interest activities in English: {{ tag_list }} {% endif %} diff --git a/api/env.example b/api/env.example index d67bbf7c..1dc4536c 100644 --- a/api/env.example +++ b/api/env.example @@ -139,7 +139,7 @@ SMTP_USER= SMTP_PASSWORD= # 本体类型融合配置 (记得写入env_example) -GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 +GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导) MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长 CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中 From b5703c1b8282e3e8472e897968c76cec7ac0430b Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 3 Mar 2026 23:30:54 +0800 Subject: [PATCH 066/164] [fix] Revising the judgment method for the interest analysis tags --- .../controllers/memory_agent_controller.py | 32 ++--- .../core/memory/analytics/hot_memory_tags.py | 112 ++++++++++++++++++ .../core/memory/utils/prompt/prompt_utils.py | 17 +++ .../prompt/prompts/interest_filter.jinja2 | 47 ++++++++ api/app/services/memory_agent_service.py | 32 ++--- 5 files changed, 210 insertions(+), 30 deletions(-) create mode 100644 api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b88e65ff..8f2e5c31 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -661,34 +661,38 @@ async def get_knowledge_type_stats_api( return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e)) -@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) -async def get_hot_memory_tags_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - limit: int = Query(20, description="返回标签数量限制"), +@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) +async def get_interest_distribution_by_user_api( + end_user_id: Optional[str] = Query(None, description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), current_user: User = Depends(get_current_user), - db: Session=Depends(get_db), + db: Session = Depends(get_db), ): """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签 - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] """ - api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") + language = get_language_from_header(language_type) + api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: - result = await memory_agent_service.get_hot_memory_tags_by_user( + result = await memory_agent_service.get_interest_distribution_by_user( end_user_id=end_user_id, - limit=limit + limit=limit, + language=language ) - return success(data=result, msg="获取热门记忆标签成功") + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: - api_logger.error(f"Hot memory tags by user failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e)) + api_logger.error(f"Interest distribution by user failed: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e)) @router.get("/analytics/user_profile", response_model=ApiResponse) diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index abb0f138..da08e88e 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -16,6 +16,10 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") +class InterestTags(BaseModel): + """用于接收LLM筛选后的兴趣活动标签列表的模型。""" + interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。") + async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 @@ -89,6 +93,70 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags +async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]: + """ + 使用LLM从标签列表中筛选出代表用户兴趣活动的标签。 + + 与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣, + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 + + Args: + tags: 原始标签列表 + end_user_id: 用户ID,用于获取LLM配置 + + Returns: + 筛选后的兴趣活动标签列表 + """ + try: + with get_db_context() as db: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + + if not config_id and not workspace_id: + raise ValueError( + f"No memory_config_id found for end_user_id: {end_user_id}." + ) + + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + + tag_list_str = ", ".join(tags) + from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt + rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language) + messages = [ + { + "role": "user", + "content": rendered_prompt + } + ] + + structured_response = await llm_client.response_structured( + messages=messages, + response_model=InterestTags + ) + + return structured_response.interest_tags + + except Exception as e: + print(f"兴趣标签LLM筛选过程中发生错误: {e}") + return tags + + async def get_raw_tags_from_db( connector: Neo4jConnector, end_user_id: str, @@ -183,3 +251,47 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = finally: # 确保关闭连接 await connector.close() + +async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]: + """ + 获取用户的兴趣分布标签。 + + 与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt, + 过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。 + + Args: + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id + limit: 最终返回的标签数量限制(默认10) + by_user: 是否按user_id查询(默认False,按end_user_id查询) + + Raises: + ValueError: 如果end_user_id未提供或为空 + """ + if not end_user_id or not end_user_id.strip(): + raise ValueError( + "end_user_id is required. Please provide a valid end_user_id or user_id." + ) + + connector = Neo4jConnector() + try: + # 查询更多原始标签,给LLM提供充足上下文 + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) + if not raw_tags_with_freq: + return [] + + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + + # 使用兴趣活动专用prompt进行筛选 + interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) + + # 保留原始频率,按兴趣筛选结果过滤 + final_tags = [ + (tag, freq) + for tag, freq in raw_tags_with_freq + if tag in interest_tag_names + ] + + return final_tags[:limit] + finally: + await connector.close() diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index d88f50cf..0cea98f2 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt( }) return rendered_prompt + + +def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str: + """ + Renders the interest filter prompt using the interest_filter.jinja2 template. + + Args: + tag_list: Comma-separated string of raw tags to filter + language: Output language ("zh" for Chinese, "en" for English) + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("interest_filter.jinja2") + rendered_prompt = template.render(tag_list=tag_list, language=language) + log_prompt_rendering('interest filter', rendered_prompt) + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 new file mode 100644 index 00000000..1e3aac55 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -0,0 +1,47 @@ +{% if language == "zh" %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in Chinese. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., '攀岩', '篮球', '游泳', '跑步') +- Tags representing cultural or entertainment hobbies (e.g., '读书', '看电影', '听音乐', '摄影') +- Tags representing learning or creative activities (e.g., '编程', '绘画', '写作', '烹饪') +- Tags representing specific interest domains or hobby categories (e.g., '历史', '天文', '园艺') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., '篮球鞋', '相机', '书桌') +- Pure location or venue names (e.g., '篮球场', '图书馆', '健身房') +- Abstract concepts or quality descriptions (e.g., '核心力量', '团队合作', '专注力') +- Person names, brand names, or proper nouns (e.g., '乔丹', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep '篮球' over '打篮球'; keep '读书' over '阅读'. + +**Example**: +Input: ['攀岩', '篮球场', '篮球鞋', '篮球', '《三体》', '历史', '核心力量', '烹饪', '菜刀'] +Output: ['攀岩', '篮球', '历史', '烹饪'] + +Please filter the following tag list and return only the tags that represent user interest activities in Chinese: {{ tag_list }} +{% else %} +You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in English. + +**Keep Rules** (keep if any condition is met): +- Tags representing sports or physical activities the user actively participates in (e.g., 'rock climbing', 'basketball', 'swimming', 'running') +- Tags representing cultural or entertainment hobbies (e.g., 'reading', 'watching movies', 'listening to music', 'photography') +- Tags representing learning or creative activities (e.g., 'programming', 'painting', 'writing', 'cooking') +- Tags representing specific interest domains or hobby categories (e.g., 'history', 'astronomy', 'gardening') + +**Filter Rules** (remove if any condition is met): +- Pure object or tool names that do not represent an activity (e.g., 'basketball shoes', 'camera', 'desk') +- Pure location or venue names (e.g., 'basketball court', 'library', 'gym') +- Abstract concepts or quality descriptions (e.g., 'core strength', 'teamwork', 'focus') +- Person names, brand names, or proper nouns (e.g., 'Jordan', 'Nike') + +**Merge Rules**: For semantically similar tags, keep only the most representative one. +For example: keep 'basketball' over 'playing basketball'; keep 'reading' over 'reading books'. + +**Example**: +Input: ['rock climbing', 'basketball court', 'basketball shoes', 'basketball', 'The Three-Body Problem', 'history', 'core strength', 'cooking', 'kitchen knife'] +Output: ['rock climbing', 'basketball', 'history', 'cooking'] + +Please filter the following tag list and return only the tags that represent user interest activities in English: {{ tag_list }} +{% endif %} diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1f3667a6..16aee283 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import ( ) from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType @@ -890,36 +890,36 @@ class MemoryAgentService: return result - async def get_hot_memory_tags_by_user( + + async def get_interest_distribution_by_user( self, end_user_id: Optional[str] = None, - limit: int = 20 + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签。 + + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 + - end_user_id: 用户ID(必填) - limit: 返回标签数量限制 + - language: 输出语言("zh" 中文, "en" 英文) 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] - - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 """ try: - # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) - tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [] - for tag, freq in tags: - payload.append({"name": tag, "frequency": freq}) - return payload + tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language) + return [{"name": tag, "frequency": freq} for tag, freq in tags] except Exception as e: - logger.error(f"热门记忆标签查询失败: {e}") - raise Exception(f"热门记忆标签查询失败: {e}") + logger.error(f"兴趣分布标签查询失败: {e}") + raise Exception(f"兴趣分布标签查询失败: {e}") async def get_user_profile( From c31a92bf01a87721afb0c87272975704b0322ad7 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 12:08:57 +0800 Subject: [PATCH 067/164] [add] Set cache for the distribution of interest tags --- api/app/cache/__init__.py | 3 +- api/app/cache/memory/__init__.py | 2 + api/app/cache/memory/interest_memory.py | 122 ++++++++++++++++++ .../controllers/memory_agent_controller.py | 19 +++ api/app/core/config.py | 2 +- .../core/memory/analytics/hot_memory_tags.py | 23 +++- .../prompt/prompts/interest_filter.jinja2 | 84 +++++++----- api/env.example | 2 +- 8 files changed, 215 insertions(+), 42 deletions(-) create mode 100644 api/app/cache/memory/interest_memory.py diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index a79d4cb2..46d1c959 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -3,9 +3,10 @@ Cache 缓存模块 提供各种缓存功能的统一入口 """ -from .memory import EmotionMemoryCache, ImplicitMemoryCache +from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache __all__ = [ "EmotionMemoryCache", "ImplicitMemoryCache", + "InterestMemoryCache", ] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 4ada3153..0e21df0f 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -5,8 +5,10 @@ Memory 缓存模块 """ from .emotion_memory import EmotionMemoryCache from .implicit_memory import ImplicitMemoryCache +from .interest_memory import InterestMemoryCache __all__ = [ "EmotionMemoryCache", "ImplicitMemoryCache", + "InterestMemoryCache", ] diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py new file mode 100644 index 00000000..108e2a37 --- /dev/null +++ b/api/app/cache/memory/interest_memory.py @@ -0,0 +1,122 @@ +""" +Interest Distribution Cache + +兴趣分布缓存模块 +用于缓存用户的兴趣分布标签数据,避免重复调用模型生成 +""" +import json +import logging +from typing import Optional, List, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + +# 缓存过期时间:24小时 +INTEREST_CACHE_EXPIRE = 86400 + + +class InterestMemoryCache: + """兴趣分布缓存类""" + + PREFIX = "cache:memory:interest_distribution" + + @classmethod + def _get_key(cls, end_user_id: str, language: str) -> str: + """生成 Redis key + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_user:{end_user_id}:{language}" + + @classmethod + async def set_interest_distribution( + cls, + end_user_id: str, + language: str, + data: List[Dict[str, Any]], + expire: int = INTEREST_CACHE_EXPIRE, + ) -> bool: + """设置用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...] + expire: 过期时间(秒),默认24小时 + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(end_user_id, language) + payload = { + "data": data, + "generated_at": datetime.now().isoformat(), + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> Optional[List[Dict[str, Any]]]: + """获取用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 兴趣分布列表,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(end_user_id, language) + value = await aio_redis.get(key) + if value: + payload = json.loads(value) + logger.info(f"命中兴趣分布缓存: {key}") + return payload.get("data") + logger.info(f"兴趣分布缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> bool: + """删除用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(end_user_id, language) + result = await aio_redis.delete(key) + logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 8f2e5c31..1f070eb6 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -1,5 +1,6 @@ from typing import List, Optional +from app.cache.memory.interest_memory import InterestMemoryCache from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header @@ -684,11 +685,29 @@ async def get_interest_distribution_by_user_api( language = get_language_from_header(language_type) api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: + # 优先读取缓存 + cached = await InterestMemoryCache.get_interest_distribution( + end_user_id=end_user_id, + language=language, + ) + if cached is not None: + api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}") + return success(data=cached, msg="获取兴趣分布标签成功") + + # 缓存未命中,调用模型生成 result = await memory_agent_service.get_interest_distribution_by_user( end_user_id=end_user_id, limit=limit, language=language ) + + # 写入缓存,24小时过期 + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + ) + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: api_logger.error(f"Interest distribution by user failed: {str(e)}") diff --git a/api/app/core/config.py b/api/app/core/config.py index 4472d373..d04e2a43 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -229,7 +229,7 @@ class Settings: # General Ontology Type Configuration # ======================================================================== # 通用本体文件路径列表(逗号分隔) - GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl") + GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl") # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index da08e88e..1d2d5259 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -281,16 +281,25 @@ async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: return [] raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq} - # 使用兴趣活动专用prompt进行筛选 + # 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签) interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) - # 保留原始频率,按兴趣筛选结果过滤 - final_tags = [ - (tag, freq) - for tag, freq in raw_tags_with_freq - if tag in interest_tag_names - ] + # 构建最终标签列表: + # - 原始标签中存在的,保留原始频率 + # - LLM推断出的新标签(不在原始列表中),赋予默认频率1 + final_tags = [] + seen = set() + for tag in interest_tag_names: + if tag in seen: + continue + seen.add(tag) + freq = raw_freq_map.get(tag, 1) + final_tags.append((tag, freq)) + + # 按频率降序排列 + final_tags.sort(key=lambda x: x[1], reverse=True) return final_tags[:limit] finally: diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 index 1e3aac55..7957bf1c 100644 --- a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -1,47 +1,67 @@ {% if language == "zh" %} -You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in Chinese. +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. -**Keep Rules** (keep if any condition is met): -- Tags representing sports or physical activities the user actively participates in (e.g., '攀岩', '篮球', '游泳', '跑步') -- Tags representing cultural or entertainment hobbies (e.g., '读书', '看电影', '听音乐', '摄影') -- Tags representing learning or creative activities (e.g., '编程', '绘画', '写作', '烹饪') -- Tags representing specific interest domains or hobby categories (e.g., '历史', '天文', '园艺') +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" -**Filter Rules** (remove if any condition is met): -- Pure object or tool names that do not represent an activity (e.g., '篮球鞋', '相机', '书桌') -- Pure location or venue names (e.g., '篮球场', '图书馆', '健身房') -- Abstract concepts or quality descriptions (e.g., '核心力量', '团队合作', '专注力') -- Person names, brand names, or proper nouns (e.g., '乔丹', 'Nike') +Examples of inference: +- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩' +- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影' +- '晨间冥想坚持天数', '身心协同峰值' → '冥想' +- '川味可视化', '川菜' → '烹饪' +- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化' +- '吉他', '指弹', '琴谱' → '吉他' +- '跑步', '5公里', '跑鞋' → '跑步' +- '瑜伽垫', '瑜伽课' → '瑜伽' -**Merge Rules**: For semantically similar tags, keep only the most representative one. -For example: keep '篮球' over '打篮球'; keep '读书' over '阅读'. +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步') +- If multiple tags all point to '攀岩', output '攀岩' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., '助手', '用户', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分') +- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影') + +**Output format**: Return a list of concise interest activity names in Chinese. **Example**: -Input: ['攀岩', '篮球场', '篮球鞋', '篮球', '《三体》', '历史', '核心力量', '烹饪', '菜刀'] -Output: ['攀岩', '篮球', '历史', '烹饪'] +Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间'] +Output: ['攀岩', '摄影', '冥想', '烹饪', '编程'] -Please filter the following tag list and return only the tags that represent user interest activities in Chinese: {{ tag_list }} +Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }} {% else %} -You are a user interest analysis expert. Your task is to identify activity-based tags from a tag list that represent the user's hobbies and interests. Please output the results in English. +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. -**Keep Rules** (keep if any condition is met): -- Tags representing sports or physical activities the user actively participates in (e.g., 'rock climbing', 'basketball', 'swimming', 'running') -- Tags representing cultural or entertainment hobbies (e.g., 'reading', 'watching movies', 'listening to music', 'photography') -- Tags representing learning or creative activities (e.g., 'programming', 'painting', 'writing', 'cooking') -- Tags representing specific interest domains or hobby categories (e.g., 'history', 'astronomy', 'gardening') +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" -**Filter Rules** (remove if any condition is met): -- Pure object or tool names that do not represent an activity (e.g., 'basketball shoes', 'camera', 'desk') -- Pure location or venue names (e.g., 'basketball court', 'library', 'gym') -- Abstract concepts or quality descriptions (e.g., 'core strength', 'teamwork', 'focus') -- Person names, brand names, or proper nouns (e.g., 'Jordan', 'Nike') +Examples of inference: +- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing' +- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography' +- 'morning meditation streak', 'mind-body peak' → 'meditation' +- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking' +- 'open source project', 'data visualization tool', 'Python' → 'programming' +- 'guitar', 'fingerpicking', 'sheet music' → 'guitar' +- 'running', '5km', 'running shoes' → 'running' -**Merge Rules**: For semantically similar tags, keep only the most representative one. -For example: keep 'basketball' over 'playing basketball'; keep 'reading' over 'reading books'. +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation') +- If multiple tags all point to 'rock climbing', output 'rock climbing' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating') + +**Output format**: Return a list of concise interest activity names in English. **Example**: -Input: ['rock climbing', 'basketball court', 'basketball shoes', 'basketball', 'The Three-Body Problem', 'history', 'core strength', 'cooking', 'kitchen knife'] -Output: ['rock climbing', 'basketball', 'history', 'cooking'] +Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time'] +Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming'] -Please filter the following tag list and return only the tags that represent user interest activities in English: {{ tag_list }} +Now process the following tag list and return the inferred interest activities in English: {{ tag_list }} {% endif %} diff --git a/api/env.example b/api/env.example index d67bbf7c..1dc4536c 100644 --- a/api/env.example +++ b/api/env.example @@ -139,7 +139,7 @@ SMTP_USER= SMTP_PASSWORD= # 本体类型融合配置 (记得写入env_example) -GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 +GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导) MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长 CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中 From 91d20f727218240ec9613495961f0bcd6a26e8af Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 12:12:21 +0800 Subject: [PATCH 068/164] feat(web): workflow chat use content replace chunk --- web/src/views/Conversation/index.tsx | 6 +++--- web/src/views/Workflow/components/Chat/Chat.tsx | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index f532ac53..2ad2a5a4 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:58:03 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 13:46:22 + * @Last Modified time: 2026-03-04 12:10:44 */ /** * Conversation Page @@ -267,8 +267,8 @@ const Conversation: FC = () => { currentConversationId = newId break case 'message': - const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; } - updateAssistantMessage(content ?? chunk) + const { content, conversation_id: curId } = item.data as { content: string; conversation_id: string; } + updateAssistantMessage(content) if (curId) { currentConversationId = curId; diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index 65989b30..e4c80e3c 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:10:56 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-28 16:43:06 + * @Last Modified time: 2026-03-04 12:10:17 */ /** * Workflow Chat Component @@ -174,8 +174,8 @@ const Chat = forwardRef(({ appId */ const handleStreamMessage = (data: SSEMessage[]) => { data.forEach(item => { - const { chunk, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as { - chunk: string; + const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as { + content: string; conversation_id: string | null; cycle_id: string; cycle_idx: number; @@ -202,7 +202,7 @@ const Chat = forwardRef(({ appId if (lastIndex >= 0) { newList[lastIndex] = { ...newList[lastIndex], - content: newList[lastIndex].content + chunk + content: newList[lastIndex].content + content } } return newList From c488eb0cd00e3fccc6f0ce090b6fc909597fd441 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 12:17:34 +0800 Subject: [PATCH 069/164] [changes] 1.Use structured logs; 2.Align the type and default value of "end_user_id" with the semantic meaning of "required". --- api/app/controllers/memory_agent_controller.py | 2 +- api/app/core/memory/analytics/hot_memory_tags.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 1f070eb6..ccf93d68 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -664,7 +664,7 @@ async def get_knowledge_type_stats_api( @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) async def get_interest_distribution_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(必填)"), + end_user_id: str = Query(..., description="用户ID(必填)"), limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), language_type: str = Header(default=None, alias="X-Language-Type"), current_user: User = Depends(get_current_user), diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 1d2d5259..6afcec6d 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -1,9 +1,12 @@ import asyncio import json +import logging import os from typing import List, Tuple from app.core.config import settings + +logger = logging.getLogger(__name__) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -89,7 +92,7 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: return structured_response.meaningful_tags except Exception as e: - print(f"LLM筛选过程中发生错误: {e}") + logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True) # 在LLM失败时返回原始标签,确保流程继续 return tags @@ -153,7 +156,7 @@ async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: return structured_response.interest_tags except Exception as e: - print(f"兴趣标签LLM筛选过程中发生错误: {e}") + logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True) return tags From 14fcb66a9c107c32f41a519ac2315165e9f08ce3 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 12:19:48 +0800 Subject: [PATCH 070/164] feat(web): short term detail use Markdown --- web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx index 6cc8eafc..f0f9ce02 100644 --- a/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx +++ b/web/src/views/UserMemoryDetail/pages/ShortTermDetail.tsx @@ -6,6 +6,7 @@ import { getShortTerm, } from '@/api/memory' import Empty from '@/components/Empty' +import Markdown from '@/components/Markdown' interface ShortTermItem { retrieval: Array<{ query: string; retrieval: string[]; }>; @@ -85,7 +86,9 @@ const ShortTermDetail: FC = () => { ))}
{t('shortTermDetail.answer')}
-
{vo.answer}
+
+ +
@@ -103,7 +106,9 @@ const ShortTermDetail: FC = () => { : data.long_term?.map((vo, voIdx) => (
{vo.query}
-
{vo.retrieval}
+
+ +
)) } From 163872be6eb3edb0089000f9371ca271517e3f2d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Mar 2026 12:12:52 +0800 Subject: [PATCH 071/164] fix(workflow): rename output message field --- .../workflow/engine/event_stream_handler.py | 2 +- .../engine/stream_output_coordinator.py | 2 +- api/app/core/workflow/executor.py | 2 +- api/app/services/workflow_service.py | 201 +++++++++--------- 4 files changed, 102 insertions(+), 105 deletions(-) diff --git a/api/app/core/workflow/engine/event_stream_handler.py b/api/app/core/workflow/engine/event_stream_handler.py index 5b7d8de2..dc3cd04d 100644 --- a/api/app/core/workflow/engine/event_stream_handler.py +++ b/api/app/core/workflow/engine/event_stream_handler.py @@ -127,7 +127,7 @@ class EventStreamHandler: yield { "event": "message", "data": { - "chunk": data.get("chunk") + "content": data.get("chunk") } } diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index ba6af156..c2885ab0 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -274,7 +274,7 @@ class StreamOutputCoordinator: yield { "event": "message", "data": { - "chunk": final_chunk + "content": final_chunk } } diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index e781b6c4..7e5bb0e4 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -272,7 +272,7 @@ class WorkflowExecutor: event_type = data.get("type", "node_chunk") # "message" or "node_chunk" if event_type == "node_chunk": async for msg_event in self.event_handler.handle_node_chunk_event(data): - full_content += msg_event["data"]["chunk"] + full_content += msg_event["data"]["content"] yield msg_event elif event_type == "node_error": diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index a388ca75..02819efb 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config from app.db import get_db @@ -23,7 +24,7 @@ from app.repositories.workflow_repository import ( WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest +from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService @@ -445,6 +446,91 @@ class WorkflowService: "success_rate": completed / total if total > 0 else 0 } + async def _handle_file_input(self, files: list[FileInput]): + if not files: + return [] + + files_struct = [] + for file in files: + files_struct.append( + { + "type": file.type, + "url": await self.multimodal_service.get_file_url(file), + "__file": True + } + ) + return files_struct + + @staticmethod + def _map_public_event(event: dict) -> dict | None: + """ + Map internal workflow events to public-facing event formats. + + Purpose: + - Hide internal execution details + - Expose a stable and simplified public event schema + - Filter out non-public events + - Maintain backward compatibility when possible + + Args: + event (dict): Internal event object, e.g.: + { + "event": "workflow_start", + "data": {...} + } + + Returns: + dict | None: + - Returns the mapped public event + - Returns None if the event should not be exposed + """ + event_type = event.get("event") + payload = event.get("data") + match event_type: + case "workflow_start": + return { + "event": "start", + "data": { + "conversation_id": payload.get("conversation_id"), + } + } + case "workflow_end": + return { + "event": "end", + "data": { + "elapsed_time": payload.get("elapsed_time"), + "message_length": len(payload.get("output", "")), + "error": payload.get("error", "") + } + } + case "node_start" | "node_end" | "node_error" | "cycle_item": + return None + case _: + return event + + def _emit(self, public: bool, internal_event: dict): + """ + Unified event emission entry. + + Args: + public (bool): + - True -> Emit mapped public event + - False -> Emit raw internal event + + internal_event (dict): + The original internal event object + + Returns: + dict | None: + - The mapped event + - Or None if the event is filtered out + """ + if public: + mapped = self._map_public_event(internal_event) + else: + mapped = internal_event + return mapped + # ==================== 工作流执行 ==================== async def run( @@ -479,10 +565,11 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id, - "files": [file.model_dump(mode='json') for file in payload.files] - } + input_data = { + "message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, + "files": [file.model_dump(mode='json') for file in payload.files] + } # 转换 conversation_id 为 UUID conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None @@ -506,22 +593,8 @@ class WorkflowService: "execution_config": config.execution_config } - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - try: - files = [] - if payload.files: - for file in payload.files: - files.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } - ) + files = await self._handle_file_input(payload.files) input_data["files"] = files # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") @@ -601,42 +674,6 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) - @staticmethod - def _map_public_event(event: dict) -> dict | None: - event_type = event.get("event") - payload = event.get("data") - match event_type: - case "workflow_start": - return { - "event": "start", - "data": { - "conversation_id": payload.get("conversation_id"), - } - } - case "workflow_end": - return { - "event": "end", - "data": { - "elapsed_time": payload.get("elapsed_time"), - "message_length": len(payload.get("output", "")), - "error": payload.get("error", "") - } - } - case "node_start" | "node_end" | "node_error" | "cycle_item": - return None - case _: - return event - - def _emit(self, public: bool, internal_event: dict): - """ - decide - """ - if public: - mapped = self._map_public_event(internal_event) - else: - mapped = internal_event - return mapped - async def run_stream( self, app_id: uuid.UUID, @@ -671,10 +708,11 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id, - "files": [file.model_dump(mode='json') for file in payload.files] - } + input_data = { + "message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, + "files": [file.model_dump(mode='json') for file in payload.files] + } # 转换 conversation_id 为 UUID conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None @@ -699,16 +737,7 @@ class WorkflowService: } try: - files = [] - if payload.files: - for file in payload.files: - files.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } - ) + files = await self._handle_file_input(payload.files) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -723,7 +752,6 @@ class WorkflowService: input_data["conv_messages"] = last_state.get("messages") or [] break init_message_length = len(input_data.get("conv_messages", [])) - from app.core.workflow.executor import execute_workflow_stream async for event in execute_workflow_stream( workflow_config=workflow_config_dict, @@ -789,37 +817,6 @@ class WorkflowService: return node.get("config", {}).get("variables", []) raise BusinessException("workflow config error - start node not found") - def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: - """清理事件数据,移除不可序列化的对象 - - Args: - event: 原始事件数据 - - Returns: - 可序列化的事件数据 - """ - from langchain_core.messages import BaseMessage - - def clean_value(value): - """递归清理值""" - if isinstance(value, BaseMessage): - # 将 Message 对象转换为字典 - return { - "type": value.__class__.__name__, - "content": value.content, - } - elif isinstance(value, dict): - return {k: clean_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [clean_value(item) for item in value] - elif isinstance(value, (str, int, float, bool, type(None))): - return value - else: - # 其他不可序列化的对象转换为字符串 - return str(value) - - return clean_value(event) - # ==================== 依赖注入函数 ==================== From 82794f051acb8fe15d556c679ebca58bf0749b67 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Mar 2026 13:49:33 +0800 Subject: [PATCH 072/164] fix(workflow): rename output message field --- api/app/core/workflow/executor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 7e5bb0e4..78149e4c 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -295,12 +295,12 @@ class WorkflowExecutor: self.graph, self.execution_context.checkpoint_config ): - full_content += msg_event["data"]['chunk'] + full_content += msg_event["data"]['content'] yield msg_event # Flush any remaining chunks async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool): - full_content += msg_event["data"]['chunk'] + full_content += msg_event["data"]['content'] yield msg_event result = graph.get_state(self.execution_context.checkpoint_config).values From b3af7571671376a7241c2baeac20590d9c4b96eb Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 13:51:31 +0800 Subject: [PATCH 073/164] [changes] Setting the environment variable for the scheduled task time --- api/app/celery_app.py | 6 +++--- api/app/core/config.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 151c1e67..c087e1d7 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -91,10 +91,10 @@ celery_app.conf.update( celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -memory_increment_schedule = crontab(hour=2, minute=0) # 每天凌晨 2:00 执行 +memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) -workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME -forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 +workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS) +forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS) #构建定时任务配置 beat_schedule_config = { diff --git a/api/app/core/config.py b/api/app/core/config.py index 4472d373..6f48fec2 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -205,6 +205,12 @@ class Settings: # Memory Cache Regeneration Configuration MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) + # Celery Beat Schedule Configuration (定时任务执行频率) + MEMORY_INCREMENT_HOUR: int = int(os.getenv("MEMORY_INCREMENT_HOUR", "2")) + MEMORY_INCREMENT_MINUTE: int = int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")) + WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")) + FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")) + # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") From 6e758faa37fa4377caaf6e6945c19dfe192381d7 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 14:17:45 +0800 Subject: [PATCH 074/164] [changes] Using Pydantic to standardize the time data for scheduled tasks --- api/app/core/config.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/api/app/core/config.py b/api/app/core/config.py index 6f48fec2..0dbebb6d 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -1,9 +1,10 @@ import json import os from pathlib import Path -from typing import Any, Dict, Optional +from typing import Annotated, Any, Dict, Optional from dotenv import load_dotenv +from pydantic import Field, TypeAdapter load_dotenv() @@ -206,10 +207,18 @@ class Settings: MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Celery Beat Schedule Configuration (定时任务执行频率) - MEMORY_INCREMENT_HOUR: int = int(os.getenv("MEMORY_INCREMENT_HOUR", "2")) - MEMORY_INCREMENT_MINUTE: int = int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")) - WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")) - FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")) + MEMORY_INCREMENT_HOUR: int = TypeAdapter( + Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")] + ).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2"))) + MEMORY_INCREMENT_MINUTE: int = TypeAdapter( + Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")] + ).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0"))) + WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter( + Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")] + ).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))) + FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter( + Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")] + ).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") From f326febc8a9bf8abb244262d1914cbc27efc6f10 Mon Sep 17 00:00:00 2001 From: yujiangping Date: Wed, 4 Mar 2026 14:40:27 +0800 Subject: [PATCH 075/164] feat:tool market add --- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + web/src/views/ToolManagement/Market.tsx | 315 ++++++++++++++++++ .../components/MarketConfigModal.tsx | 173 ++++++++++ web/src/views/ToolManagement/index.tsx | 12 +- 5 files changed, 501 insertions(+), 1 deletion(-) create mode 100644 web/src/views/ToolManagement/Market.tsx create mode 100644 web/src/views/ToolManagement/components/MarketConfigModal.tsx diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index d404dd6e..b9ab09b0 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1774,6 +1774,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re mcp: 'MCP Services', inner: 'Built-in Tools', custom: 'Custom Tools', + market: 'Tool Market', mcpSearchPlaceholder: 'Search MCP Services...', innerSearchPlaceholder: 'Search Tools...', customSearchPlaceholder: 'Search Custom Tools...', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index fc6bb822..71a20207 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1771,6 +1771,7 @@ export const zh = { mcp: 'MCP 服务', inner: '内置工具', custom: '自定义工具', + market: '工具市场', mcpSearchPlaceholder: '搜索MCP服务...', innerSearchPlaceholder: '搜索工具...', customSearchPlaceholder: '搜索自定义工具...', diff --git a/web/src/views/ToolManagement/Market.tsx b/web/src/views/ToolManagement/Market.tsx new file mode 100644 index 00000000..59fbddcc --- /dev/null +++ b/web/src/views/ToolManagement/Market.tsx @@ -0,0 +1,315 @@ +import React, { useState, useRef, type ReactNode } from 'react'; +import { Input, Button, Spin, App } from 'antd'; +import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons'; +import { useTranslation } from 'react-i18next'; +import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal'; + +interface MarketSource { + id: string; + name: string; + category: string; + icon: string; + url: string; + desc: string; + apiKey: string; + connected: boolean; + mcpCount: number; +} + +interface MarketMcp { + id: string; + name: string; + provider: string; + type: string; + desc: string; + downloads?: string; + stars?: string; + icon: string; + configTemplate: any; +} + +interface MarketCategory { + id: string; + name: string; + icon: string; +} + +const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => { + const { t } = useTranslation(); + const { message } = App.useApp(); + const [loading, setLoading] = useState(false); + const [selectedSource, setSelectedSource] = useState(null); + const marketConfigModalRef = useRef(null); + const [marketSources, setMarketSources] = useState([ + { id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 }, + { id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 }, + { id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 }, + { id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 }, + { id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 }, + { id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 }, + ]); + + const [categories] = useState([ + { id: 'official', name: '官方/综合', icon: '🌐' }, + { id: 'china-cloud', name: '国内云', icon: '☁️' }, + { id: 'community', name: '社区/垂直', icon: '👥' } + ]); + + const [mcpCache, setMcpCache] = useState>({ + 'github-mcp': [ + { id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} }, + { id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} }, + { id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成,支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} }, + ] + }); + + const [searchKeyword, setSearchKeyword] = useState(''); + + const handleSelectSource = (sourceId: string) => { + setSelectedSource(sourceId); + }; + + const handleRefresh = (sourceId: string) => { + setLoading(true); + setTimeout(() => { + // 模拟刷新数据 + const source = marketSources.find(s => s.id === sourceId); + if (source) { + message.success(`${source.name} 列表已刷新`); + } + setLoading(false); + }, 600); + }; + + const handleOpenConfig = (sourceId: string) => { + const source = marketSources.find(s => s.id === sourceId); + if (source) { + marketConfigModalRef.current?.handleOpen(source); + } + }; + + const handleConnect = (sourceId: string, apiKey: string) => { + // 更新市场源状态 + setMarketSources(prev => prev.map(source => { + if (source.id === sourceId) { + return { + ...source, + apiKey, + connected: true + }; + } + return source; + })); + + // 模拟获取MCP列表 + setTimeout(() => { + const source = marketSources.find(s => s.id === sourceId); + if (source && !mcpCache[sourceId]) { + // 生成模拟数据 + const mockData: MarketMcp[] = [ + { id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} }, + { id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} } + ]; + setMcpCache(prev => ({ + ...prev, + [sourceId]: mockData + })); + } + message.success(`已连接 ${source?.name}`); + }, 800); + }; + + const renderSourceDetail = () => { + if (!selectedSource) { + return ( +
+
🏪
+

选择一个 MCP 市场

+

从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务

+
+ ); + } + + const source = marketSources.find(s => s.id === selectedSource); + if (!source) return null; + + const mcpList = mcpCache[selectedSource] || []; + const filteredList = mcpList.filter(mcp => + mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) || + mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase()) + ); + + return ( + <> +
+
+
+ {source.icon} +
+
+

{source.name}

+

{source.desc}

+
+
+
+ + +
+
+ +
+
+

+ 可用 MCP 服务 ({mcpList.length}) +

+
+ {source.connected && ( + + )} + {mcpList.length > 0 && ( + } + placeholder="搜索服务..." + value={searchKeyword} + onChange={(e) => setSearchKeyword(e.target.value)} + style={{ width: 200 }} + /> + )} +
+
+ + {mcpList.length > 0 ? ( + +
+ {filteredList.map(mcp => ( +
+
+
+ {mcp.icon} +
+ + {mcp.type} + +
+

{mcp.name}

+ {mcp.provider && ( +
+ @ {mcp.provider} +
+ )} +

{mcp.desc}

+
+ {mcp.downloads && ( + + {mcp.downloads} + + )} + {mcp.stars && ( + + ⭐ {mcp.stars} + + )} +
+
+ +
+
+ ))} +
+
+ ) : ( +
+
{source.connected ? '📭' : '🔌'}
+

+ {source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'} +

+

+ {source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'} +

+ {!source.connected && ( + + )} +
+ )} +
+ + ); + }; + + return ( +
+ {/* 左侧市场源列表 */} +
+
+ MCP 市场 +
+ {categories.map(cat => ( +
+
+ {cat.icon} + {cat.name} +
+
+ {marketSources + .filter(s => s.category === cat.id) + .map(source => ( +
handleSelectSource(source.id)} + > + {source.icon} + + {source.name} + + + {source.mcpCount} + + {source.connected && ( + + )} +
+ ))} +
+
+ ))} +
+ + {/* 右侧内容区 */} +
+
+ {renderSourceDetail()} +
+
+ + {/* 配置弹窗 */} + +
+ ); +}; + +export default Market; diff --git a/web/src/views/ToolManagement/components/MarketConfigModal.tsx b/web/src/views/ToolManagement/components/MarketConfigModal.tsx new file mode 100644 index 00000000..d1d87563 --- /dev/null +++ b/web/src/views/ToolManagement/components/MarketConfigModal.tsx @@ -0,0 +1,173 @@ +import { forwardRef, useImperativeHandle, useState } from 'react'; +import { Form, Input, Button, App, Space } from 'antd'; +import { useTranslation } from 'react-i18next'; +import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons'; +import RbModal from '@/components/RbModal'; + +const FormItem = Form.Item; + +interface MarketSource { + id: string; + name: string; + icon: string; + url: string; + desc: string; + apiKey: string; + connected: boolean; +} + +interface MarketConfigModalProps { + onConnect: (sourceId: string, apiKey: string) => void; +} + +export interface MarketConfigModalRef { + handleOpen: (source: MarketSource) => void; + handleClose: () => void; +} + +const MarketConfigModal = forwardRef(({ + onConnect +}, ref) => { + const { t } = useTranslation(); + const { message } = App.useApp(); + const [visible, setVisible] = useState(false); + const [form] = Form.useForm(); + const [loading, setLoading] = useState(false); + const [currentSource, setCurrentSource] = useState(null); + const [showApiKey, setShowApiKey] = useState(false); + + const handleClose = () => { + setVisible(false); + form.resetFields(); + setLoading(false); + setCurrentSource(null); + setShowApiKey(false); + }; + + const handleOpen = (source: MarketSource) => { + setCurrentSource(source); + form.setFieldsValue({ + url: source.url, + apiKey: source.apiKey, + }); + setVisible(true); + }; + + const handleSave = () => { + form + .validateFields() + .then((values) => { + if (!currentSource) return; + + setLoading(true); + + // 模拟连接延迟 + setTimeout(() => { + onConnect(currentSource.id, values.apiKey || ''); + message.success(`正在连接 ${currentSource.name}...`); + setLoading(false); + handleClose(); + }, 500); + }) + .catch((err) => { + console.log('表单验证失败:', err); + }); + }; + + const handleCopyUrl = () => { + if (currentSource?.url) { + navigator.clipboard.writeText(currentSource.url).then(() => { + message.success(t('common.copySuccess')); + }); + } + }; + + useImperativeHandle(ref, () => ({ + handleOpen, + handleClose + })); + + if (!currentSource) return null; + + return ( + +
+ {/* 市场源信息头部 */} +
+
+ {currentSource.icon} +
+
+

{currentSource.name}

+

{currentSource.desc}

+
+
+ +
+ {/* 市场地址 */} + + + + + + + + {/* API Key */} + + API Key (可选) + + } + extra="部分市场需要 API Key 才能获取完整的服务列表" + > + + +
+
+ ); +}); + +export default MarketConfigModal; diff --git a/web/src/views/ToolManagement/index.tsx b/web/src/views/ToolManagement/index.tsx index 9fa73067..d684ebdd 100644 --- a/web/src/views/ToolManagement/index.tsx +++ b/web/src/views/ToolManagement/index.tsx @@ -1,3 +1,11 @@ +/* + * @Description: + * @Version: 0.0.1 + * @Author: yujiangping + * @Date: 2026-01-05 17:22:23 + * @LastEditors: yujiangping + * @LastEditTime: 2026-03-04 12:24:01 + */ import React, { useState } from 'react'; import { Tabs } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -5,9 +13,10 @@ import { useTranslation } from 'react-i18next'; import Mcp from './Mcp'; import Inner from './Inner'; import Custom from './Custom'; +import Market from './Market'; import Tag from '@/components/Tag' -const tabKeys = ['mcp', 'inner', 'custom'] +const tabKeys = ['mcp', 'inner', 'custom', 'market'] const ToolManagement: React.FC = () => { const { t } = useTranslation(); const [activeTab, setActiveTab] = useState('mcp'); @@ -45,6 +54,7 @@ const ToolManagement: React.FC = () => { {activeTab === 'mcp' && } {activeTab === 'inner' && } {activeTab === 'custom' && } + {activeTab === 'market' && }
); }; From 85aea97c21190fe91c014cb586e5f11a3f1b7e80 Mon Sep 17 00:00:00 2001 From: yujiangping Date: Wed, 4 Mar 2026 15:13:14 +0800 Subject: [PATCH 076/164] chore(web): disable market tab in tool management - Comment out Market component rendering in ToolManagement view - Update LastEditTime timestamp in file header - Market tab functionality temporarily disabled pending further developmen --- web/src/views/ToolManagement/index.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/views/ToolManagement/index.tsx b/web/src/views/ToolManagement/index.tsx index d684ebdd..c9383258 100644 --- a/web/src/views/ToolManagement/index.tsx +++ b/web/src/views/ToolManagement/index.tsx @@ -4,7 +4,7 @@ * @Author: yujiangping * @Date: 2026-01-05 17:22:23 * @LastEditors: yujiangping - * @LastEditTime: 2026-03-04 12:24:01 + * @LastEditTime: 2026-03-04 15:12:48 */ import React, { useState } from 'react'; import { Tabs } from 'antd'; @@ -54,7 +54,7 @@ const ToolManagement: React.FC = () => { {activeTab === 'mcp' && } {activeTab === 'inner' && } {activeTab === 'custom' && } - {activeTab === 'market' && } + {/* {activeTab === 'market' && } */} ); }; From d4c4160215f9bdf35f941d7c3242dbed9795fa1c Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Wed, 4 Mar 2026 15:28:17 +0800 Subject: [PATCH 077/164] =?UTF-8?q?=E3=80=90ADD]Knowledge=20base=20retriev?= =?UTF-8?q?al=20supports=20file=20set=20retrieval?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/controllers/chunk_controller.py | 8 ++++---- api/app/schemas/chunk_schema.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 620d8a1a..988aa706 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -441,14 +441,14 @@ async def retrieve_chunks( # 1 participle search, 2 semantic search, 3 hybrid search match retrieve_data.retrieve_type: case chunk_schema.RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case chunk_schema.RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) + rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case _: - rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) + rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) # Efficient deduplication seen_ids = set() unique_rs = [] diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index cef9b9cb..ce8f70f2 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel): class ChunkRetrieve(BaseModel): query: str kb_ids: list[uuid.UUID] + file_names_filter: list[str] | None = Field(None) similarity_threshold: float | None = Field(None) vector_similarity_weight: float | None = Field(None) top_k: int | None = Field(None) From 778bc4bd7008f7a1ea3b0aa6a8c6b822b9240603 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Mar 2026 15:58:49 +0800 Subject: [PATCH 078/164] fix(workflow): fix incorrect fields in streaming API output --- api/app/controllers/service/app_api_controller.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 61a919b1..64143f57 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -249,6 +249,7 @@ async def chat( app_id=app.id, workspace_id=workspace_id, release_id=app.current_release.id, + public=True ): event_type = event.get("event", "message") event_data = event.get("data", {}) From 440e8acd990547baaf7e6ba70c5471693b312a14 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 16:42:15 +0800 Subject: [PATCH 079/164] feat(web): mcp tool add form rules --- web/src/i18n/en.ts | 4 +++- web/src/i18n/zh.ts | 4 +++- .../ToolManagement/components/McpServiceModal.tsx | 14 ++++++++++++-- .../components/RequestHeaderModal.tsx | 13 +++++++++++-- 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 94dd67cc..a80a0aa4 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1942,7 +1942,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re path: 'Path', viewDetail: 'View Details', textLink: 'Test Connection', - noResult: 'Processing results will be displayed here' + noResult: 'Processing results will be displayed here', + serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces', + requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore', }, workflow: { coreNode: 'Core Nodes', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 49632789..1302fdad 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1939,7 +1939,9 @@ export const zh = { path: '路径', viewDetail: '查看详情', textLink: '测试连接', - noResult: '处理结果将显示在这里' + noResult: '处理结果将显示在这里', + serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格', + requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾', }, workflow: { coreNode: '核心节点', diff --git a/web/src/views/ToolManagement/components/McpServiceModal.tsx b/web/src/views/ToolManagement/components/McpServiceModal.tsx index a104c2d6..bd97b876 100644 --- a/web/src/views/ToolManagement/components/McpServiceModal.tsx +++ b/web/src/views/ToolManagement/components/McpServiceModal.tsx @@ -9,6 +9,7 @@ import RequestHeaderModal from './RequestHeaderModal'; import Table from '@/components/Table'; import { addTool, updateTool, testConnection } from '@/api/tools' import type { McpServiceModalRef } from '../types' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -168,14 +169,22 @@ const McpServiceModal = forwardRef(({ name={['config', "server_url"]} label={t('tool.serviceEndpoint')} extra={t('tool.serviceEndpointExtra')} - rules={[{ required: true, message: t('common.pleaseEnter') }]} + rules={[ + { required: true, message: t('common.pleaseEnter') }, + { max: 500 }, + { pattern: /^https?:\/\/\S+$/, message: t('tool.serverUrlInvalid') }, + ]} >
@@ -201,6 +210,7 @@ const McpServiceModal = forwardRef(({ diff --git a/web/src/views/ToolManagement/components/RequestHeaderModal.tsx b/web/src/views/ToolManagement/components/RequestHeaderModal.tsx index 5e20120d..1f2bdff3 100644 --- a/web/src/views/ToolManagement/components/RequestHeaderModal.tsx +++ b/web/src/views/ToolManagement/components/RequestHeaderModal.tsx @@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next'; import type { RequestHeader, RequestHeaderModalRef } from './McpServiceModal' import RbModal from '@/components/RbModal' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -82,7 +83,11 @@ const RequestHeaderModal = forwardRef @@ -90,7 +95,11 @@ const RequestHeaderModal = forwardRef From 4ee198813a360787b29356219bd49f33b3af0e68 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 16:46:25 +0800 Subject: [PATCH 080/164] feat(web): custom tool add form rules --- .../views/ToolManagement/components/CustomToolModal.tsx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/web/src/views/ToolManagement/components/CustomToolModal.tsx b/web/src/views/ToolManagement/components/CustomToolModal.tsx index e3203a74..0ea8aca9 100644 --- a/web/src/views/ToolManagement/components/CustomToolModal.tsx +++ b/web/src/views/ToolManagement/components/CustomToolModal.tsx @@ -6,6 +6,7 @@ import type { CustomToolItem, CustomToolModalRef, ToolItem } from '../types' import RbModal from '@/components/RbModal'; import { parseSchema, addTool, updateTool } from '@/api/tools'; import Table from '@/components/Table'; +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; interface CustomToolModalProps { @@ -134,7 +135,11 @@ const CustomToolModal = forwardRef(({ From f5eda38dc99edbf71affe3a56b04d40bc830b0a6 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 17:04:25 +0800 Subject: [PATCH 081/164] feat(web): ontology extract add form rules --- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + web/src/views/Ontology/components/OntologyClassExtractModal.tsx | 1 + 3 files changed, 3 insertions(+) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index a80a0aa4..6410a2e9 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -456,6 +456,7 @@ export const en = { logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, imageSquareRequired: 'Please upload a square image', nameInvalid: 'Name cannot start or end with a space', + notAllSpaces: 'Cannot be all spaces', }, model: { searchPlaceholder: 'search model…', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 1302fdad..889154f4 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1032,6 +1032,7 @@ export const zh = { logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, imageSquareRequired: '请上传正方形比例图片', nameInvalid: '不能是空格开头或结尾', + notAllSpaces: '不能是纯空格', }, model: { searchPlaceholder: '搜索模型…', diff --git a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx index 2fd305c6..8d3e1a91 100644 --- a/web/src/views/Ontology/components/OntologyClassExtractModal.tsx +++ b/web/src/views/Ontology/components/OntologyClassExtractModal.tsx @@ -185,6 +185,7 @@ const OntologyClassExtractModal = forwardRef From 53dbe2f436bc1f71dde7f8ff46f9dfab68f861b3 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 17:26:30 +0800 Subject: [PATCH 082/164] [fix] Fix the external write memory API --- api/app/controllers/service/memory_api_controller.py | 2 +- api/app/services/memory_api_service.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index accd749e..34489e8a 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. """ - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") memory_api_service = MemoryAPIService(db) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a8c39a5a..ad0a8164 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -140,9 +140,11 @@ class MemoryAPIService: try: # Delegate to MemoryAgentService + # Convert string message to list[dict] format expected by MemoryAgentService + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, - messages=message, + messages=messages, config_id=config_id, db=self.db, storage_type=storage_type, @@ -151,8 +153,13 @@ class MemoryAPIService: logger.info(f"Memory write successful for end_user: {end_user_id}") + # result may be a string "success" or a dict with a "status" key + if isinstance(result, dict): + status = result.get("status", "success") + else: + status = result if isinstance(result, str) else "success" return { - "status": "success" if result == "success" else result, + "status": status, "end_user_id": end_user_id } From efe3865aa44fdb8ed8ad469a5b714565a83d512e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 17:26:30 +0800 Subject: [PATCH 083/164] [fix] Fix the external write memory API --- api/app/controllers/service/memory_api_controller.py | 2 +- api/app/services/memory_api_service.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index accd749e..34489e8a 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. """ - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") memory_api_service = MemoryAPIService(db) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a8c39a5a..ad0a8164 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -140,9 +140,11 @@ class MemoryAPIService: try: # Delegate to MemoryAgentService + # Convert string message to list[dict] format expected by MemoryAgentService + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, - messages=message, + messages=messages, config_id=config_id, db=self.db, storage_type=storage_type, @@ -151,8 +153,13 @@ class MemoryAPIService: logger.info(f"Memory write successful for end_user: {end_user_id}") + # result may be a string "success" or a dict with a "status" key + if isinstance(result, dict): + status = result.get("status", "success") + else: + status = result if isinstance(result, str) else "success" return { - "status": "success" if result == "success" else result, + "status": status, "end_user_id": end_user_id } From ba36ccb21fd42c678f9b3748d7525dd747787c0e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 17:46:13 +0800 Subject: [PATCH 084/164] [changes] Hide the user knowledge base and unify the display of memory capacity --- .../memory_dashboard_controller.py | 4 +- api/app/repositories/knowledge_repository.py | 43 +++++++++++++++++ api/app/services/memory_dashboard_service.py | 46 +++++++++++++++++-- api/app/tasks.py | 13 ++++-- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 475d184e..1b5b45fb 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -606,8 +606,8 @@ async def dashboard_data( # 获取RAG相关数据 try: - # total_memory: 使用 total_chunk(总chunk数) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) + # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数 + total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user) rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 681d1c10..e3832214 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int except Exception as e: db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}") raise + + +def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和 + """ + db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}") + + try: + from sqlalchemy import func + result = db.query(func.sum(Knowledge.chunk_num)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id == "Memory" + ).scalar() + + total = result if result is not None else 0 + db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}") + return total + except Exception as e: + db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}") + raise + + +def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量 + """ + db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}") + + try: + count = db.query(Knowledge).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id != "Memory" + ).count() + + db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}") + return count + except Exception as e: + db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}") + raise + diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 8d6071cc..22752805 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -390,19 +390,59 @@ def get_rag_total_kb( current_user: User ) -> int: """ - 根据当前用户所在的workspace_id查询konwledges表所有不同id的数量 + 根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量 """ workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}") + business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}") try: - total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id) + total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id) business_logger.info(f"成功获取RAG总知识库数: {total_kb}") return total_kb except Exception as e: business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}") raise + +def get_rag_user_kb_total_chunk( + db: Session, + current_user: User +) -> int: + """ + 根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。 + 与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。 + """ + workspace_id = current_user.current_workspace_id + business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}") + + try: + from app.models.document_model import Document + from app.models.end_user_model import EndUser + from app.models.app_model import App + from sqlalchemy import func + + # 通过 App 关联取该 workspace 下所有 end_user_id + end_user_ids = [ + str(u.id) for u in db.query(EndUser.id) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ] + if not end_user_ids: + return 0 + + file_names = [f"{uid}.txt" for uid in end_user_ids] + result = db.query(func.sum(Document.chunk_num)).filter( + Document.file_name.in_(file_names) + ).scalar() + + total_chunk = int(result or 0) + business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}") + return total_chunk + except Exception as e: + business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}") + raise + def get_current_user_total_chunk( end_user_id: str, db: Session, diff --git a/api/app/tasks.py b/api/app/tasks.py index 299d188b..671a03f4 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -62,7 +62,7 @@ def process_item(item: dict): @celery_app.task(name="app.core.rag.tasks.parse_document") -def parse_document(file_path: str, document_id: uuid.UUID): +def parse_document(file_path: str, document_id: str): """ Document parsing, vectorization, and storage """ @@ -74,6 +74,9 @@ def parse_document(file_path: str, document_id: uuid.UUID): db = next(get_db()) # Manually call the generator db_document = None db_knowledge = None + # 确保 document_id 是 UUID 对象 + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" try: db_document = db.query(Document).filter(Document.id == document_id).first() @@ -282,11 +285,13 @@ def parse_document(file_path: str, document_id: uuid.UUID): result = f"parse document '{db_document.file_name}' processed successfully." return result except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + if db_document is not None: + db_document.progress_msg = (db_document.progress_msg or "") + f"Failed to vectorize and import the parsed document: {str(e)}\n" db_document.run = 0 db.commit() - result = f"parse document '{db_document.file_name}' failed." + result = f"parse document '{db_document.file_name}' failed." + else: + result = f"parse document '{document_id}' failed: document not found in DB. error={str(e)}" return result finally: db.close() From 850d9ee70b098b18d604c6b763e167a855c76fe5 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 17:46:13 +0800 Subject: [PATCH 085/164] [changes] Hide the user knowledge base and unify the display of memory capacity --- .../memory_dashboard_controller.py | 4 +- api/app/repositories/knowledge_repository.py | 43 +++++++++++++++++ api/app/services/memory_dashboard_service.py | 46 +++++++++++++++++-- api/app/tasks.py | 13 ++++-- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 475d184e..1b5b45fb 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -606,8 +606,8 @@ async def dashboard_data( # 获取RAG相关数据 try: - # total_memory: 使用 total_chunk(总chunk数) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) + # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数 + total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user) rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 681d1c10..e3832214 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int except Exception as e: db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}") raise + + +def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和 + """ + db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}") + + try: + from sqlalchemy import func + result = db.query(func.sum(Knowledge.chunk_num)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id == "Memory" + ).scalar() + + total = result if result is not None else 0 + db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}") + return total + except Exception as e: + db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}") + raise + + +def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量 + """ + db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}") + + try: + count = db.query(Knowledge).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id != "Memory" + ).count() + + db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}") + return count + except Exception as e: + db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}") + raise + diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 8d6071cc..22752805 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -390,19 +390,59 @@ def get_rag_total_kb( current_user: User ) -> int: """ - 根据当前用户所在的workspace_id查询konwledges表所有不同id的数量 + 根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量 """ workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}") + business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}") try: - total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id) + total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id) business_logger.info(f"成功获取RAG总知识库数: {total_kb}") return total_kb except Exception as e: business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}") raise + +def get_rag_user_kb_total_chunk( + db: Session, + current_user: User +) -> int: + """ + 根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。 + 与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。 + """ + workspace_id = current_user.current_workspace_id + business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}") + + try: + from app.models.document_model import Document + from app.models.end_user_model import EndUser + from app.models.app_model import App + from sqlalchemy import func + + # 通过 App 关联取该 workspace 下所有 end_user_id + end_user_ids = [ + str(u.id) for u in db.query(EndUser.id) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ] + if not end_user_ids: + return 0 + + file_names = [f"{uid}.txt" for uid in end_user_ids] + result = db.query(func.sum(Document.chunk_num)).filter( + Document.file_name.in_(file_names) + ).scalar() + + total_chunk = int(result or 0) + business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}") + return total_chunk + except Exception as e: + business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}") + raise + def get_current_user_total_chunk( end_user_id: str, db: Session, diff --git a/api/app/tasks.py b/api/app/tasks.py index 093f081f..4f7bfacc 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -62,7 +62,7 @@ def process_item(item: dict): @celery_app.task(name="app.core.rag.tasks.parse_document") -def parse_document(file_path: str, document_id: uuid.UUID): +def parse_document(file_path: str, document_id: str): """ Document parsing, vectorization, and storage """ @@ -74,6 +74,9 @@ def parse_document(file_path: str, document_id: uuid.UUID): db = next(get_db()) # Manually call the generator db_document = None db_knowledge = None + # 确保 document_id 是 UUID 对象 + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" try: db_document = db.query(Document).filter(Document.id == document_id).first() @@ -286,11 +289,13 @@ def parse_document(file_path: str, document_id: uuid.UUID): result = f"parse document '{db_document.file_name}' processed successfully." return result except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + if db_document is not None: + db_document.progress_msg = (db_document.progress_msg or "") + f"Failed to vectorize and import the parsed document: {str(e)}\n" db_document.run = 0 db.commit() - result = f"parse document '{db_document.file_name}' failed." + result = f"parse document '{db_document.file_name}' failed." + else: + result = f"parse document '{document_id}' failed: document not found in DB. error={str(e)}" return result finally: db.close() From 817221347f7b5b34bf54c5be581719388bc69267 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 17:57:58 +0800 Subject: [PATCH 086/164] [fix] Preserve full result dict and default status to "unknown" instead of "success". --- api/app/services/memory_api_service.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index ad0a8164..f86fbed8 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -154,13 +154,17 @@ class MemoryAPIService: logger.info(f"Memory write successful for end_user: {end_user_id}") # result may be a string "success" or a dict with a "status" key + # Preserve the full dict so callers don't silently lose extra fields + # (e.g. error codes, metadata) returned by MemoryAgentService. if isinstance(result, dict): - status = result.get("status", "success") - else: - status = result if isinstance(result, str) else "success" + return { + **result, + "status": result.get("status", "unknown"), + "end_user_id": end_user_id, + } return { - "status": status, - "end_user_id": end_user_id + "status": result if isinstance(result, str) else "success", + "end_user_id": end_user_id, } except ConfigurationError as e: From 420f391f3c242ea4d485e44bcd5fa25ad6df9e29 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 18:01:56 +0800 Subject: [PATCH 087/164] [fix] Fixed tuple unpacking and moved UUID conversion into the try block. --- api/app/services/memory_dashboard_service.py | 2 +- api/app/tasks.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 22752805..05aed57e 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -423,7 +423,7 @@ def get_rag_user_kb_total_chunk( # 通过 App 关联取该 workspace 下所有 end_user_id end_user_ids = [ - str(u.id) for u in db.query(EndUser.id) + str(eid) for (eid,) in db.query(EndUser.id) .join(App, EndUser.app_id == App.id) .filter(App.workspace_id == workspace_id) .all() diff --git a/api/app/tasks.py b/api/app/tasks.py index 4f7bfacc..2846071a 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -74,11 +74,11 @@ def parse_document(file_path: str, document_id: str): db = next(get_db()) # Manually call the generator db_document = None db_knowledge = None - # 确保 document_id 是 UUID 对象 - if not isinstance(document_id, uuid.UUID): - document_id = uuid.UUID(str(document_id)) progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" try: + # 确保 document_id 是 UUID 对象 + if not isinstance(document_id, uuid.UUID): + document_id = uuid.UUID(str(document_id)) db_document = db.query(Document).filter(Document.id == document_id).first() db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() # 1. Document parsing & segmentation From 8aad8faae9b40036486a7ef3e4bb4abd97d7aca2 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Wed, 4 Mar 2026 18:05:54 +0800 Subject: [PATCH 088/164] fix(web): chat loading fix --- web/src/views/ApplicationConfig/components/Chat.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 8cb6812c..62f7c592 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:27:39 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 14:21:54 + * @Last Modified time: 2026-03-04 18:05:36 */ /** * Chat debugging component for application testing @@ -217,6 +217,8 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc } } if (!isCanSend) { + setLoading(false) + setCompareLoading(false) return } runCompare(data.app_id, { From 647a9788657e1ccea2e1b620d1141c2af28cf58e Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Mar 2026 19:07:40 +0800 Subject: [PATCH 089/164] [fix] Restore task --- api/app/tasks.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index 2846071a..093f081f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -62,7 +62,7 @@ def process_item(item: dict): @celery_app.task(name="app.core.rag.tasks.parse_document") -def parse_document(file_path: str, document_id: str): +def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ @@ -76,9 +76,6 @@ def parse_document(file_path: str, document_id: str): db_knowledge = None progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" try: - # 确保 document_id 是 UUID 对象 - if not isinstance(document_id, uuid.UUID): - document_id = uuid.UUID(str(document_id)) db_document = db.query(Document).filter(Document.id == document_id).first() db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() # 1. Document parsing & segmentation @@ -289,13 +286,11 @@ def parse_document(file_path: str, document_id: str): result = f"parse document '{db_document.file_name}' processed successfully." return result except Exception as e: - if db_document is not None: - db_document.progress_msg = (db_document.progress_msg or "") + f"Failed to vectorize and import the parsed document: {str(e)}\n" + if 'db_document' in locals(): + db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" db_document.run = 0 db.commit() - result = f"parse document '{db_document.file_name}' failed." - else: - result = f"parse document '{document_id}' failed: document not found in DB. error={str(e)}" + result = f"parse document '{db_document.file_name}' failed." return result finally: db.close() From 590ec3a446d957f17e9d9e05189f2a155dd378ab Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Thu, 5 Mar 2026 09:55:54 +0800 Subject: [PATCH 090/164] feat(model and app): 1. Increase support for visual models and multimodal models; 2. The application and workflow can input various multimodal files such as images, documents, audio, and videos. --- api/app/controllers/app_controller.py | 3 +- api/app/controllers/model_controller.py | 4 +- api/app/controllers/ontology_controller.py | 17 +- api/app/core/agent/langchain_agent.py | 153 ++++--- api/app/core/models/base.py | 30 +- .../core/models/scripts/bedrock_models.yaml | 38 +- .../core/models/scripts/dashscope_models.yaml | 172 ++++++- api/app/core/models/scripts/loader.py | 45 +- .../core/models/scripts/openai_models.yaml | 64 ++- api/app/models/models_model.py | 13 +- api/app/schemas/app_schema.py | 16 +- api/app/schemas/model_schema.py | 18 + api/app/services/app_chat_service.py | 6 +- api/app/services/app_service.py | 2 +- .../services/audio_transcription_service.py | 101 +++++ .../services/collaborative_orchestrator.py | 2 + api/app/services/draft_run_service.py | 11 +- api/app/services/handoffs_service.py | 1 + api/app/services/llm_router.py | 1 + api/app/services/master_agent_router.py | 1 + api/app/services/model_service.py | 58 ++- api/app/services/multi_agent_orchestrator.py | 2 + api/app/services/multi_agent_service.py | 2 +- api/app/services/multimodal_service.py | 426 +++++++++++++----- api/app/services/prompt_optimizer_service.py | 3 +- api/app/services/shared_chat_service.py | 2 + 26 files changed, 958 insertions(+), 233 deletions(-) create mode 100644 api/app/services/audio_transcription_service.py diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index e2849ad6..653f616c 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -835,7 +835,8 @@ async def draft_run_compare( web_search=True, memory=True, parallel=payload.parallel, - timeout=payload.timeout or 60 + timeout=payload.timeout or 60, + files=payload.files ) logger.info( diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index bb1ba526..0de3d4fe 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -469,7 +469,9 @@ async def create_model_api_key_by_provider( config=api_key_data.config, is_active=api_key_data.is_active, priority=api_key_data.priority, - model_config_ids=model_config_ids + model_config_ids=model_config_ids, + capability=api_key_data.capability, + is_omni=api_key_data.is_omni ) created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index e4a87141..42d4bee0 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -124,15 +124,23 @@ def _get_ontology_service( ) # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) - from app.repositories.model_repository import ModelApiKeyRepository - api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) - if not api_keys: + # from app.repositories.model_repository import ModelApiKeyRepository + from app.services.model_service import ModelApiKeyService + api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id) + if not api_key_config: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, detail="指定的LLM模型没有可用的API密钥" ) - api_key_config = api_keys[0] + # api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + # if not api_keys: + # logger.error(f"Model {llm_id} has no active API key") + # raise HTTPException( + # status_code=400, + # detail="指定的LLM模型没有可用的API密钥" + # ) + # api_key_config = api_keys[0] is_composite = getattr(model_config, 'is_composite', False) logger.info( @@ -154,6 +162,7 @@ def _get_ontology_service( provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, max_retries=3, timeout=60.0 ) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index fae20ea2..88b6371c 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,35 +11,37 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType +from app.models.models_model import ModelType, ModelProvider from app.services.memory_agent_service import ( get_end_user_connected_config, ) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool + logger = get_business_logger() class LangChainAgent: def __init__( - self, - model_name: str, - api_key: str, - provider: str = "openai", - api_base: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2000, - system_prompt: Optional[str] = None, - tools: Optional[Sequence[BaseTool]] = None, - streaming: bool = False, - max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) - max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 + self, + model_name: str, + api_key: str, + provider: str = "openai", + api_base: Optional[str] = None, + is_omni: bool = False, + temperature: float = 0.7, + max_tokens: int = 2000, + system_prompt: Optional[str] = None, + tools: Optional[Sequence[BaseTool]] = None, + streaming: bool = False, + max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) + max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 ): """初始化 LangChain Agent @@ -60,12 +62,13 @@ class LangChainAgent: self.provider = provider self.tools = tools or [] self.streaming = streaming + self.is_omni = is_omni self.max_tool_consecutive_calls = max_tool_consecutive_calls - + # 工具调用计数器:记录每个工具的连续调用次数 self.tool_call_counter: Dict[str, int] = {} self.last_tool_called: Optional[str] = None - + # 根据工具数量动态调整最大迭代次数 # 基础值 + 每个工具额外的调用机会 if max_iterations is None: @@ -73,9 +76,9 @@ class LangChainAgent: self.max_iterations = 5 + len(self.tools) * 2 else: self.max_iterations = max_iterations - + self.system_prompt = system_prompt or "你是一个专业的AI助手" - + logger.debug( f"Agent 迭代次数配置: max_iterations={self.max_iterations}, " f"tool_count={len(self.tools)}, " @@ -89,6 +92,7 @@ class LangChainAgent: provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni, extra_params={ "temperature": temperature, "max_tokens": max_tokens, @@ -143,21 +147,22 @@ class LangChainAgent: """ from langchain_core.tools import StructuredTool from functools import wraps - + wrapped_tools = [] - + for original_tool in tools: tool_name = original_tool.name original_func = original_tool.func if hasattr(original_tool, 'func') else None - + if not original_func: # 如果无法获取原始函数,直接使用原工具 wrapped_tools.append(original_tool) continue - + # 创建包装函数 def make_wrapped_func(tool_name, original_func): """创建包装函数的工厂函数,避免闭包问题""" + @wraps(original_func) def wrapped_func(*args, **kwargs): """包装后的工具函数,跟踪连续调用次数""" @@ -168,13 +173,13 @@ class LangChainAgent: # 切换到新工具,重置计数器 self.tool_call_counter[tool_name] = 1 self.last_tool_called = tool_name - + current_count = self.tool_call_counter[tool_name] - + logger.debug( f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}" ) - + # 检查是否超过最大连续调用次数 if current_count > self.max_tool_consecutive_calls: logger.warning( @@ -185,12 +190,12 @@ class LangChainAgent: f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次," f"未找到有效结果。请尝试其他方法或直接回答用户的问题。" ) - + # 调用原始工具函数 return original_func(*args, **kwargs) - + return wrapped_func - + # 使用 StructuredTool 创建新工具 wrapped_tool = StructuredTool( name=original_tool.name, @@ -198,17 +203,17 @@ class LangChainAgent: func=make_wrapped_func(tool_name, original_func), args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None ) - + wrapped_tools.append(wrapped_tool) - + return wrapped_tools def _prepare_messages( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - files: Optional[List[Dict[str, Any]]] = None + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + files: Optional[List[Dict[str, Any]]] = None ) -> List[BaseMessage]: """准备消息列表 @@ -248,7 +253,7 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -261,23 +266,26 @@ class LangChainAgent: List[Dict]: 消息内容列表 """ # 根据 provider 使用不同的文本格式 - if self.provider.lower() in ["bedrock", "anthropic"]: - # Anthropic/Bedrock: {"type": "text", "text": "..."} - content_parts = [{"type": "text", "text": text}] - else: - # 通义千问等: {"text": "..."} - content_parts = [{"text": text}] - + # if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE, + # ModelProvider.GPUSTACK] or ( + # self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)): + # # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."} + # content_parts = [{"type": "text", "text": text}] + # else: + # # 通义千问等: {"text": "..."} + # content_parts = [{"type": "text", "text": text}] + content_parts = [{"type": "text", "text": text}] + # 添加文件内容 # MultimodalService 已经根据 provider 返回了正确格式,直接使用 content_parts.extend(files) - + logger.debug( f"构建多模态消息: provider={self.provider}, " f"parts={len(content_parts)}, " f"files={len(files)}" ) - + return content_parts async def chat( @@ -302,7 +310,7 @@ class LangChainAgent: Returns: Dict: 包含 content 和元数据的字典 """ - message_chat= message + message_chat = message start_time = time.time() actual_config_id = config_id # If config_id is None, try to get from end_user's connected config @@ -322,8 +330,8 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') + logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') + print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -367,14 +375,14 @@ class LangChainAgent: # 获取最后的 AI 消息 output_messages = result.get("messages", []) content = "" - + logger.debug(f"输出消息数量: {len(output_messages)}") total_tokens = 0 for msg in reversed(output_messages): if isinstance(msg, AIMessage): logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}") logger.debug(f"AI 消息内容: {msg.content}") - + # 处理多模态响应:content 可能是字符串或列表 if isinstance(msg.content, str): content = msg.content @@ -407,12 +415,13 @@ class LangChainAgent: response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break - + logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, + actual_config_id) response = { "content": content, "model": self.model_name, @@ -439,16 +448,16 @@ class LangChainAgent: raise async def chat_stream( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - end_user_id:Optional[str] = None, - config_id: Optional[str] = None, - storage_type:Optional[str] = None, - user_rag_memory_id:Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + end_user_id: Optional[str] = None, + config_id: Optional[str] = None, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + memory_flag: Optional[bool] = True, + files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -482,7 +491,6 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) @@ -500,13 +508,13 @@ class LangChainAgent: full_content = '' try: async for event in self.agent.astream_events( - {"messages": messages}, - version="v2", - config={"recursion_limit": self.max_iterations} + {"messages": messages}, + version="v2", + config={"recursion_limit": self.max_iterations} ): chunk_count += 1 kind = event.get("event") - + # 处理所有可能的流式事件 if kind == "on_chat_model_stream": # LLM 流式输出 @@ -540,7 +548,7 @@ class LangChainAgent: full_content += item yield item yielded_content = True - + elif kind == "on_llm_stream": # 另一种 LLM 流式事件 chunk = event.get("data", {}).get("chunk") @@ -577,13 +585,13 @@ class LangChainAgent: full_content += chunk yield chunk yielded_content = True - + # 记录工具调用(可选) elif kind == "on_tool_start": logger.debug(f"工具调用开始: {event.get('name')}") elif kind == "on_tool_end": logger.debug(f"工具调用结束: {event.get('name')}") - + logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 output_messages = event.get("data", {}).get("output", {}).get("messages", []) @@ -595,7 +603,8 @@ class LangChainAgent: yield total_tokens break if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, + actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise @@ -609,5 +618,3 @@ class LangChainAgent: logger.info("=" * 80) logger.info("chat_stream 方法执行结束") logger.info("=" * 80) - - diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index f5f49af0..5d4dbd10 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -27,6 +27,7 @@ class RedBearModelConfig(BaseModel): provider: str api_key: str base_url: Optional[str] = None + is_omni: bool = False # 是否为 Omni 模型 # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 @@ -45,7 +46,28 @@ class RedBearModelFactory: # 打印供应商信息用于调试 from app.core.logging_config import get_business_logger logger = get_business_logger() - logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}") + logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}") + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + import httpx + if not config.base_url: + config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + timeout_config = httpx.Timeout( + timeout=config.timeout, + connect=60.0, + read=config.timeout, + write=60.0, + pool=10.0, + ) + return { + "model": config.model_name, + "base_url": config.base_url, + "api_key": config.api_key, + "timeout": timeout_config, + "max_retries": config.max_retries, + **config.extra_params + } if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -135,6 +157,12 @@ class RedBearModelFactory: def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: """根据模型提供商获取对应的模型类""" provider = config.provider.lower() + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + from langchain_openai import ChatOpenAI + return ChatOpenAI + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if type == ModelType.LLM: from langchain_openai import OpenAI diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index e5b91d1c..2c0ab757 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -6,6 +6,8 @@ models: description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: bedrock @@ -15,6 +17,9 @@ models: description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -28,6 +33,9 @@ models: description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -42,6 +50,8 @@ models: description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -54,6 +64,9 @@ models: description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -67,6 +80,8 @@ models: description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -78,6 +93,8 @@ models: description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -89,6 +106,8 @@ models: description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -101,6 +120,8 @@ models: description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -113,6 +134,8 @@ models: description: amazon.rerank-v1:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: bedrock @@ -122,6 +145,8 @@ models: description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: bedrock @@ -131,6 +156,9 @@ models: description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 文本嵌入模型 - vision @@ -141,6 +169,8 @@ models: description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -150,6 +180,8 @@ models: description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -159,6 +191,8 @@ models: description: Cohere Embed 3 English文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -168,6 +202,8 @@ models: description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 - logo: bedrock + logo: bedrock \ No newline at end of file diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index af1c3619..89a16966 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -6,6 +6,8 @@ models: description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -16,6 +18,8 @@ models: description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -26,6 +30,8 @@ models: description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -36,6 +42,8 @@ models: description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -46,6 +54,8 @@ models: description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -56,6 +66,8 @@ models: description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -66,6 +78,8 @@ models: description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -76,6 +90,8 @@ models: description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -88,6 +104,8 @@ models: description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +118,9 @@ models: description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -112,6 +133,9 @@ models: description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -124,6 +148,8 @@ models: description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -135,6 +161,8 @@ models: description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -147,6 +175,8 @@ models: description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -159,6 +189,8 @@ models: description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -171,6 +203,8 @@ models: description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -182,6 +216,8 @@ models: description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -193,6 +229,8 @@ models: description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -205,6 +243,8 @@ models: description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -217,6 +257,8 @@ models: description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -229,6 +271,8 @@ models: description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -241,6 +285,8 @@ models: description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -253,6 +299,8 @@ models: description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -265,6 +313,8 @@ models: description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -277,6 +327,8 @@ models: description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -289,6 +341,10 @@ models: description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -302,6 +358,10 @@ models: description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -315,6 +375,10 @@ models: description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -328,6 +392,10 @@ models: description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -341,6 +409,10 @@ models: description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -354,6 +426,10 @@ models: description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -367,6 +443,8 @@ models: description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -379,6 +457,8 @@ models: description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -391,6 +471,8 @@ models: description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -403,6 +485,8 @@ models: description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -415,6 +499,8 @@ models: description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -427,6 +513,8 @@ models: description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -439,6 +527,8 @@ models: description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -451,6 +541,8 @@ models: description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -463,6 +555,8 @@ models: description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -475,6 +569,8 @@ models: description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -487,6 +583,8 @@ models: description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -498,6 +596,8 @@ models: description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -509,6 +609,8 @@ models: description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -520,6 +622,8 @@ models: description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -531,6 +635,8 @@ models: description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -544,6 +650,8 @@ models: description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -557,6 +665,8 @@ models: description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -569,6 +679,8 @@ models: description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -582,6 +694,8 @@ models: description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -594,6 +708,8 @@ models: description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -606,6 +722,11 @@ models: description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + - audio + is_omni: true tags: - 大语言模型 - 多模态模型 @@ -620,6 +741,10 @@ models: description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -635,6 +760,10 @@ models: description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -650,6 +779,10 @@ models: description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -665,6 +798,10 @@ models: description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -680,6 +817,10 @@ models: description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -695,6 +836,10 @@ models: description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -708,6 +853,10 @@ models: description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -721,6 +870,8 @@ models: description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -732,6 +883,8 @@ models: description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -743,6 +896,8 @@ models: description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -754,6 +909,8 @@ models: description: gte-rerank-v2重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -763,6 +920,8 @@ models: description: gte-rerank重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -772,6 +931,9 @@ models: description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 嵌入模型 - 多模态模型 @@ -783,6 +945,8 @@ models: description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -793,6 +957,8 @@ models: description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -803,6 +969,8 @@ models: description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -813,7 +981,9 @@ models: description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 - logo: dashscope + logo: dashscope \ No newline at end of file diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index a14d3268..e4462efa 100644 --- a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -6,7 +6,7 @@ from typing import Callable import yaml from sqlalchemy.orm import Session -from app.models.models_model import ModelBase, ModelProvider +from app.models.models_model import ModelBase, ModelProvider, ModelConfig def _load_yaml_config(provider: ModelProvider) -> list[dict]: @@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") for model_data in models: + config_sync_fields = { + "logo": None, + "capability": None, + "is_omni": None, + "name": None, + "provider": None, + "type": None, + "description": None + } try: # 检查模型是否已存在 existing = db.query(ModelBase).filter( @@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) # 更新现有模型配置 for key, value in model_data.items(): setattr(existing, key, value) + + # 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey + sync_fields = [k for k in config_sync_fields.keys() if k in model_data] + if sync_fields: + # 批量更新 ModelConfig + update_kwargs = {k: model_data[k] for k in sync_fields} + db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update( + update_kwargs, + synchronize_session=False + ) + + # 更新 ModelApiKey 的 capability 和 is_omni + if 'capability' in model_data or 'is_omni' in model_data: + from app.models.models_model import ModelApiKey, model_config_api_key_association + api_key_update = {} + if 'capability' in model_data: + api_key_update['capability'] = model_data['capability'] + if 'is_omni' in model_data: + api_key_update['is_omni'] = model_data['is_omni'] + + if api_key_update: + # 查找所有关联的 API Key + api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join( + ModelConfig, + ModelConfig.id == model_config_api_key_association.c.model_config_id + ).filter(ModelConfig.model_id == existing.id).distinct().all() + + if api_key_ids: + api_key_ids = [aid[0] for aid in api_key_ids] + db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update( + api_key_update, + synchronize_session=False + ) + db.commit() if not silent: print(f"更新成功: {model_data['name']}") diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 68c63ee2..7f6d3a51 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -6,12 +6,19 @@ models: description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - audio + - video + is_omni: true tags: - 大语言模型 - multi-tool-call - agent-thought - stream-tool-call - vision + - audio + - video logo: openai - name: gpt-3.5-turbo-0125 type: llm @@ -19,6 +26,8 @@ models: description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -31,6 +40,8 @@ models: description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -43,6 +54,8 @@ models: description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -55,6 +68,8 @@ models: description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: openai @@ -64,6 +79,8 @@ models: description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -76,6 +93,8 @@ models: description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -88,6 +107,8 @@ models: description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +121,9 @@ models: description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -113,6 +137,8 @@ models: description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -125,6 +151,9 @@ models: description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -138,6 +167,8 @@ models: description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -148,6 +179,9 @@ models: description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -162,6 +196,9 @@ models: description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -176,6 +213,8 @@ models: description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -189,6 +228,8 @@ models: description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -202,6 +243,9 @@ models: description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -215,6 +259,9 @@ models: description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -228,6 +275,9 @@ models: description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -242,6 +292,9 @@ models: description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -256,6 +309,9 @@ models: description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -270,6 +326,8 @@ models: description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -279,6 +337,8 @@ models: description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -288,6 +348,8 @@ models: description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 - logo: openai + logo: openai \ No newline at end of file diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 3e378f17..23fafcef 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,7 +2,7 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship from sqlalchemy.sql import func @@ -78,6 +78,9 @@ class ModelConfig(BaseModel): description = Column(String, comment="模型描述") # 模型配置参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") config = Column(JSON, comment="模型配置参数") # - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。 # - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。 @@ -118,6 +121,11 @@ class ModelApiKey(BaseModel): api_key = Column(String, nullable=False, comment="API密钥") api_base = Column(String, comment="API基础URL") + # 模型能力参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") + # 配置参数 config = Column(JSON, comment="API Key特定配置") @@ -155,6 +163,9 @@ class ModelBase(Base): tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])") add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数") created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now()) + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") # 关联关系 configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan") diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 07875e13..f073a200 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -21,8 +21,14 @@ class FileType(StrEnum): def trans(cls, value: str) -> 'FileType': if value.startswith("image"): return cls.IMAGE - # TODO: other file type support - raise RuntimeError("Unsupport file type") + elif value.startswith("document"): + return cls.DOCUMENT + elif value.startswith("audio"): + return cls.AUDIO + elif value.startswith("video"): + return cls.VIDEO + else: + raise RuntimeError("Unsupport file type") class TransferMethod(str, Enum): @@ -37,6 +43,12 @@ class FileInput(BaseModel): transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url") upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)") url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") + file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + + def __init__(self, **data): + if "type" in data: + data['file_type'] = data['type'] + super().__init__(**data) @field_validator("type", mode="before") @classmethod diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 0c0bbeed..f25d9408 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel): is_active: bool = Field(True, description="是否激活") is_public: bool = Field(False, description="是否公开") load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略") + capability: List[str] = Field(default_factory=list, description="模型能力列表") + is_omni: bool = Field(False, description="是否为Omni模型") class ApiKeyCreateNested(BaseModel): @@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel): provider: Optional[str] = Field(None, description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") priority: str = Field("1", description="优先级", max_length=10) @@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel): config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数") is_active: Optional[bool] = Field(None, description="是否激活") is_public: Optional[bool] = Field(None, description="是否公开") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelConfig(ModelConfigBase): @@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel): api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) description: Optional[str] = Field(None, description="备注") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel): provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: List[str] = Field(default_factory=list, description="模型能力列表") + is_omni: bool = Field(False, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel): provider: Optional[ModelProvider] = Field(None, description="API Key提供商") api_key: Optional[str] = Field(None, description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") is_active: Optional[bool] = Field(None, description="是否激活") priority: Optional[str] = Field(None, description="优先级", max_length=10) @@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel): description: Optional[str] = Field(None, description="模型描述") is_official: bool = Field(True, description="是否供应商官方模型") tags: List[str] = Field(default_factory=list, description="模型标签") + capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video'])") + is_omni: bool = Field(False, description="是否为Omni模型") class ModelBaseUpdate(BaseModel): @@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_official: Optional[bool] = Field(None, description="是否供应商官方模型") tags: Optional[List[str]] = Field(None, description="模型标签") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelBase(BaseModel): @@ -298,6 +314,8 @@ class ModelBase(BaseModel): is_official: bool tags: List[str] add_count: int + capability: List[str] = [] + is_omni: bool = False class ModelBaseQuery(BaseModel): diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 9723121d..e6ac227b 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -157,6 +157,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -180,7 +181,7 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") @@ -343,6 +344,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -366,7 +368,7 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 6e6e0ecb..c5919af9 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -232,7 +232,7 @@ class AppService: # 检查主 Agent 的模型配置 multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id - model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) diff --git a/api/app/services/audio_transcription_service.py b/api/app/services/audio_transcription_service.py new file mode 100644 index 00000000..11d13f38 --- /dev/null +++ b/api/app/services/audio_transcription_service.py @@ -0,0 +1,101 @@ +""" +音频转文本服务 + +支持的服务商: +- DashScope (阿里云通义千问) +- OpenAI Whisper +""" +import httpx + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class AudioTranscriptionService: + """音频转文本服务""" + + @staticmethod + async def transcribe_dashscope(audio_url: str, api_key: str) -> str: + """ + 使用阿里云通义千问语音识别服务转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: DashScope API Key + + Returns: + str: 转录的文本 + """ + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "X-DashScope-Async": "enable", + }, + json={ + "model": "paraformer-v2", + "input": { + "file_urls": [audio_url] + }, + "parameters": { + "language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"] + } + } + ) + response.raise_for_status() + result = response.json() + + if result.get("output", {}).get("results"): + text = result["output"]["results"][0].get("transcription_text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + return "[音频转文本失败]" + + except Exception as e: + logger.error(f"DashScope 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" + + @staticmethod + async def transcribe_openai(audio_url: str, api_key: str) -> str: + """ + 使用 OpenAI Whisper 转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: OpenAI API Key + + Returns: + str: 转录的文本 + """ + try: + # 下载音频文件 + async with httpx.AsyncClient(timeout=60.0) as client: + audio_response = await client.get(audio_url) + audio_response.raise_for_status() + audio_data = audio_response.content + + # 调用 Whisper API + files = {"file": ("audio.mp3", audio_data, "audio/mpeg")} + data = {"model": "whisper-1"} + + response = await client.post( + "https://api.openai.com/v1/audio/transcriptions", + headers={"Authorization": f"Bearer {api_key}"}, + files=files, + data=data + ) + response.raise_for_status() + result = response.json() + + text = result.get("text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + except Exception as e: + logger.error(f"OpenAI Whisper 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" diff --git a/api/app/services/collaborative_orchestrator.py b/api/app/services/collaborative_orchestrator.py index 00a731de..68181cd1 100644 --- a/api/app/services/collaborative_orchestrator.py +++ b/api/app/services/collaborative_orchestrator.py @@ -445,6 +445,7 @@ class CollaborativeOrchestrator: "provider": api_key_config.provider, "api_key": api_key_config.api_key, "api_base": api_key_config.api_base, + "is_omni": api_key_config.is_omni, "model_parameters": config_data.get("model_parameters", {}), "api_key_id": api_key_config.id } @@ -511,6 +512,7 @@ class CollaborativeOrchestrator: provider=agent_config["provider"], api_key=agent_config["api_key"], base_url=agent_config.get("api_base"), + is_omni=agent_config.get("is_omni", False), extra_params=extra_params ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 8977710b..693f1a26 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -415,6 +415,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -442,7 +443,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -683,6 +684,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -711,7 +713,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -809,7 +811,7 @@ class DraftRunService: """ return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]: + async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict: """获取模型的 API Key Args: @@ -846,7 +848,8 @@ class DraftRunService: "provider": api_key.provider, "api_key": api_key.api_key, "api_base": api_key.api_base, - "api_key_id": api_key.id + "api_key_id": api_key.id, + "is_omni": api_key.is_omni } async def _ensure_conversation( diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index e490eea4..8418fe31 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs( provider=model_api_key.provider, api_key=model_api_key.api_key, base_url=model_api_key.api_base, + is_omni=model_api_key.is_omni, extra_params={ "temperature": 0.7, "max_tokens": 2000, diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index e56ad5aa..02895d6b 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -414,6 +414,7 @@ class LLMRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.3, max_tokens=500 ) diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 3cf3ecc3..b0f43b51 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -392,6 +392,7 @@ class MasterAgentRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, extra_params = extra_params ) diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index aa8cfbac..2337427a 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -90,7 +90,8 @@ class ModelConfigService: api_key: str, api_base: Optional[str] = None, model_type: str = "llm", - test_message: str = "Hello" + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -102,6 +103,7 @@ class ModelConfigService: api_base: API基础URL model_type: 模型类型 (llm/chat/embedding/rerank) test_message: 测试消息 + is_omni: 是否为Omni模型 Returns: Dict: 验证结果 @@ -114,14 +116,27 @@ class ModelConfigService: try: start_time = time.time() - model_config = RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, - temperature=0.7, - max_tokens=100 - ) + # dashscope 的 omni 模型需要使用 compatible-mode + if provider.lower() == ModelProvider.DASHSCOPE and is_omni: + if not api_base: + api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" + model_config = RedBearModelConfig( + model_name=model_name, + provider=ModelProvider.OPENAI, + api_key=api_key, + base_url=api_base, + temperature=0.7, + max_tokens=100 + ) + else: + model_config = RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base, + temperature=0.7, + max_tokens=100 + ) # 根据模型类型选择不同的验证方式 model_type_lower = model_type.lower() @@ -257,8 +272,9 @@ class ModelConfigService: provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, - model_type=model_data.type, # 传递模型类型 - test_message="Hello" + model_type=model_data.type, + test_message="Hello", + is_omni=model_data.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -279,6 +295,9 @@ class ModelConfigService: for api_key_data in api_key_datas: api_key_data.model_name = model_data.name api_key_data.provider = model_data.provider + # 同步capability和is_omni + api_key_data.capability = model_data.capability + api_key_data.is_omni = model_data.is_omni api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], **api_key_data.model_dump() @@ -497,6 +516,8 @@ class ModelApiKeyService: existing_key.config = data.config existing_key.priority = data.priority existing_key.model_name = model_name + existing_key.capability = data.capability + existing_key.is_omni = data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -513,7 +534,8 @@ class ModelApiKeyService: api_key=data.api_key, api_base=data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=data.is_omni ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 @@ -528,6 +550,8 @@ class ModelApiKeyService: provider=data.provider, api_key=data.api_key, api_base=data.api_base, + capability=data.capability if data.capability is not None else model_config.capability, + is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni, config=data.config, is_active=data.is_active, priority=data.priority @@ -572,6 +596,8 @@ class ModelApiKeyService: existing_key.config = api_key_data.config existing_key.priority = api_key_data.priority existing_key.model_name = api_key_data.model_name + existing_key.capability = api_key_data.capability + existing_key.is_omni = api_key_data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -589,7 +615,8 @@ class ModelApiKeyService: api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=model_config.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -620,7 +647,8 @@ class ModelApiKeyService: api_key=api_key_data.api_key or existing_api_key.api_key, api_base=api_key_data.api_base or existing_api_key.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=model_config.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -755,6 +783,8 @@ class ModelBaseService: "type": model_base.type, "logo": model_base.logo, "description": model_base.description, + "capability": model_base.capability, + "is_omni": model_base.is_omni, "is_composite": False } model_config = ModelConfigRepository.create(db, model_config_data) diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d1aa46d1..650f639b 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -2593,6 +2593,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, # 整合任务使用中等温度 max_tokens=2000 ) @@ -2758,6 +2759,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, max_tokens=2000, extra_params={"streaming": True} # 启用流式输出 diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index c52814ed..751099d5 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -267,7 +267,7 @@ class MultiAgentService: # 2. 验证模型配置(如果提供了) if data.default_model_config_id: - model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(data.default_model_config_id)) diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index bfb23a56..9b06c287 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,47 +9,100 @@ - OpenAI: 支持 URL 和 base64 格式 """ import uuid -from typing import List, Dict, Any, Optional, Protocol +import httpx +import base64 +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod from sqlalchemy.orm import Session +from docx import Document +import io +import PyPDF2 from app.core.logging_config import get_business_logger from app.core.exceptions import BusinessException from app.core.error_codes import BizCode from app.schemas.app_schema import FileInput, FileType, TransferMethod -from app.models.generic_file_model import GenericFile +from app.models.file_metadata_model import FileMetadata +from app.core.config import settings +from app.services.audio_transcription_service import AudioTranscriptionService logger = get_business_logger() -class ImageFormatStrategy(Protocol): - """图片格式策略接口""" +class MultimodalFormatStrategy(ABC): + """多模态格式策略基类""" + + @abstractmethod + async def format_image(self, url: str) -> Dict[str, Any]: + """格式化图片""" + pass + + @abstractmethod + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """格式化文档""" + pass + + @abstractmethod + async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]: + """格式化音频""" + pass + + @abstractmethod + async def format_video(self, url: str) -> Dict[str, Any]: + """格式化视频""" + pass + + +class DashScopeFormatStrategy(MultimodalFormatStrategy): + """通义千问策略""" async def format_image(self, url: str) -> Dict[str, Any]: - """将图片 URL 转换为特定 provider 的格式""" - ... - - -class DashScopeImageStrategy: - """通义千问图片格式策略""" - - async def format_image(self, url: str) -> Dict[str, Any]: - """通义千问格式: {"type": "image", "image": "url"}""" + """通义千问图片格式:{"type": "image", "image": "url"}""" return { "type": "image", "image": url } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """通义千问文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } -class BedrockImageStrategy: - """Bedrock/Anthropic 图片格式策略""" + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + 通义千问音频格式 + - 原生支持: qwen-audio 系列 + - 其他模型: 需要转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + # 通义千问音频格式:{"type": "audio", "audio": "url"} + return { + "type": "audio", + "audio": url + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """通义千问视频格式(qwen-vl 系列原生支持)""" + return { + "type": "video", + "video": url + } + + +class BedrockFormatStrategy(MultimodalFormatStrategy): + """Bedrock/Anthropic 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} """ - import httpx - import base64 from mimetypes import guess_type logger.info(f"下载并编码图片: {url}") @@ -84,9 +137,46 @@ class BedrockImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """Bedrock/Anthropic 文档格式(需要 base64 编码)""" + # Bedrock 文档需要 base64 编码 + text_bytes = text.encode('utf-8') + base64_text = base64.b64encode(text_bytes).decode('utf-8') -class OpenAIImageStrategy: - """OpenAI 图片格式策略""" + return { + "type": "document", + "source": { + "type": "base64", + "media_type": "text/plain", + "data": base64_text + } + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + Bedrock/Anthropic 音频格式 + 不支持原生音频,必须转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"[音频转录]\n{transcription}" + } + return { + "type": "text", + "text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """Bedrock/Anthropic 视频格式""" + return { + "type": "text", + "text": f"" + } + + +class OpenAIFormatStrategy(MultimodalFormatStrategy): + """OpenAI 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" @@ -97,29 +187,97 @@ class OpenAIImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """OpenAI 文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + OpenAI 音频格式 + - gpt-4o-audio 系列支持原生音频(需要 base64 编码) + - 其他模型使用转录文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + + # OpenAI 音频需要 base64 编码 + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + audio_data = response.content + base64_audio = base64.b64encode(audio_data).decode('utf-8') + # 1. 优先从 file_type (MIME) 取扩展名 + file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None + # 2. 从响应头 content-type 取 + if not file_ext: + ct = response.headers.get("content-type", "") + file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None + # 3. 从 URL 路径取扩展名 + if not file_ext: + file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None + # 4. 默认 wav + # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} + file_ext = "wav" if not file_ext else file_ext + + return { + "type": "input_audio", + "input_audio": { + "data": f"data:;base64,{base64_audio}", + "format": file_ext + } + } + except Exception as e: + logger.error(f"下载音频失败: {e}") + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """OpenAI 视频格式""" + return { + "type": "video_url", + "video_url": { + "url": url + } + } + # Provider 到策略的映射 PROVIDER_STRATEGIES = { - "dashscope": DashScopeImageStrategy, - "bedrock": BedrockImageStrategy, - "anthropic": BedrockImageStrategy, - "openai": OpenAIImageStrategy, + "dashscope": DashScopeFormatStrategy, + "bedrock": BedrockFormatStrategy, + "anthropic": BedrockFormatStrategy, + "openai": OpenAIFormatStrategy, } class MultimodalService: """多模态文件处理服务""" - def __init__(self, db: Session, provider: str = "dashscope"): + def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False): """ 初始化多模态服务 Args: db: 数据库会话 - provider: 模型提供商(dashscope, bedrock, anthropic 等) + provider: 模型提供商(dashscope, bedrock, anthropic, openai 等) + api_key: API 密钥(用于音频转文本) + enable_audio_transcription: 是否启用音频转文本 + is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式) """ self.db = db self.provider = provider.lower() + self.api_key = api_key + self.enable_audio_transcription = enable_audio_transcription + self.is_omni = is_omni async def process_files( self, @@ -137,20 +295,32 @@ class MultimodalService: if not files: return [] + # 获取对应的策略 + # dashscope 的 omni 模型使用 OpenAI 兼容格式 + if self.provider == "dashscope" and self.is_omni: + strategy_class = OpenAIFormatStrategy + else: + strategy_class = PROVIDER_STRATEGIES.get(self.provider) + if not strategy_class: + logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") + strategy_class = DashScopeFormatStrategy + + strategy = strategy_class() + result = [] for idx, file in enumerate(files): try: if file.type == FileType.IMAGE: - content = await self._process_image(file) + content = await self._process_image(file, strategy) result.append(content) elif file.type == FileType.DOCUMENT: - content = await self._process_document(file) + content = await self._process_document(file, strategy) result.append(content) elif file.type == FileType.AUDIO: - content = await self._process_audio(file) + content = await self._process_audio(file, strategy) result.append(content) elif file.type == FileType.VIDEO: - content = await self._process_video(file) + content = await self._process_video(file, strategy) result.append(content) else: logger.warning(f"不支持的文件类型: {file.type}") @@ -172,55 +342,29 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - async def _process_image(self, file: FileInput) -> Dict[str, Any]: + async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理图片文件 Args: file: 图片文件输入 + strategy: 格式化策略 Returns: - Dict: 根据 provider 返回不同格式 - - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - - 通义千问: {"type": "image", "image": "url"} + Dict: 根据 provider 返回不同格式的图片内容 """ - url = await self.get_file_url(file) - - logger.debug(f"处理图片: {url}, provider={self.provider}") - - # 根据 provider 返回不同格式 - if self.provider in ["bedrock", "anthropic"]: - # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 - try: - logger.info(f"开始下载并编码图片: {url}") - base64_data, media_type = await self._download_and_encode_image(url) - result = { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": base64_data[:100] + "..." # 只记录前100个字符 - } - } - logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}") - # 返回完整数据 - result["source"]["data"] = base64_data - return result - except Exception as e: - logger.error(f"下载并编码图片失败: {e}", exc_info=True) - # 返回错误提示 - return { - "type": "text", - "text": f"[图片加载失败: {str(e)}]" - } - else: - # 通义千问等其他格式支持 URL + try: + url = await self.get_file_url(file) + return await strategy.format_image(url) + except Exception as e: + logger.error(f"处理图片失败: {e}", exc_info=True) return { - "type": "image", - "image": url + "type": "text", + "text": f"[图片处理失败: {str(e)}]" } - async def _download_and_encode_image(self, url: str) -> tuple[str, str]: + @staticmethod + async def _download_and_encode_image(url: str) -> tuple[str, str]: """ 下载图片并转换为 base64 @@ -230,8 +374,6 @@ class MultimodalService: Returns: tuple: (base64_data, media_type) """ - import httpx - import base64 from mimetypes import guess_type # 下载图片 @@ -258,15 +400,16 @@ class MultimodalService: return base64_data, media_type - async def _process_document(self, file: FileInput) -> Dict[str, Any]: + async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) Args: file: 文档文件输入 + strategy: 格式化策略 Returns: - Dict: text 格式的内容(包含提取的文本) + Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: # 远程文档暂不支持提取 @@ -277,48 +420,68 @@ class MultimodalService: else: # 本地文件,提取文本内容 text = await self._extract_document_text(file.upload_file_id) - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file.upload_file_id + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file.upload_file_id ).first() - file_name = generic_file.file_name if generic_file else "unknown" + file_name = file_metadata.file_name if file_metadata else "unknown" - return { - "type": "text", - "text": f"\n{text}\n" - } + # 使用策略格式化文档 + return await strategy.format_document(file_name, text) - async def _process_audio(self, file: FileInput) -> Dict[str, Any]: + async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理音频文件 Args: file: 音频文件输入 + strategy: 格式化策略 Returns: - Dict: 音频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的音频内容 """ - # TODO: 实现音频转文字功能 - return { - "type": "text", - "text": "[音频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) - async def _process_video(self, file: FileInput) -> Dict[str, Any]: + # 如果启用音频转文本且有 API Key + transcription = None + if self.enable_audio_transcription and self.api_key: + logger.info(f"开始音频转文本: {url}") + if self.provider == "dashscope": + transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) + elif self.provider == "openai": + transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) + else: + logger.warning(f"Provider {self.provider} 不支持音频转文本") + + return await strategy.format_audio(file.file_type, url, transcription) + except Exception as e: + logger.error(f"处理音频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理视频文件 Args: file: 视频文件输入 + strategy: 格式化策略 Returns: - Dict: 视频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的视频内容 """ - # TODO: 实现视频处理功能 - return { - "type": "text", - "text": "[视频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) + return await strategy.format_video(url) + except Exception as e: + logger.error(f"处理视频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[视频处理失败: {str(e)}]" + } async def get_file_url(self, file: FileInput) -> str: """ @@ -336,26 +499,22 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return file.url else: - # 本地文件,通过 file_storage 系统获取永久访问 URL - from app.models.file_metadata_model import FileMetadata - from app.core.config import settings - file_id = file.upload_file_id print("="*50) print("file_id",file_id) - + # 查询 FileMetadata file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file_id, FileMetadata.status == "completed" ).first() - + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - + # 返回永久URL server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" @@ -370,58 +529,79 @@ class MultimodalService: Returns: str: 提取的文本内容 """ - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file_id, - GenericFile.status == "active" + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file_id, + FileMetadata.status == "completed" ).first() - if not generic_file: + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - # TODO: 根据文件类型提取文本 - # - PDF: 使用 PyPDF2 或 pdfplumber - # - Word: 使用 python-docx - # - TXT/MD: 直接读取 - - file_ext = generic_file.file_ext.lower() + file_ext = file_metadata.file_ext.lower() + server_url = settings.FILE_LOCAL_SERVER_URL + file_url = f"{server_url}/storage/permanent/{file_id}" if file_ext in ['.txt', '.md', '.markdown']: - return await self._read_text_file(generic_file.storage_path) + return await self._read_text_file(file_url) elif file_ext == '.pdf': - return await self._extract_pdf_text(generic_file.storage_path) + return await self._extract_pdf_text(file_url) elif file_ext in ['.doc', '.docx']: - return await self._extract_word_text(generic_file.storage_path) + return await self._extract_word_text(file_url) else: return f"[不支持的文档格式: {file_ext}]" - async def _read_text_file(self, storage_path: str) -> str: + @staticmethod + async def _read_text_file(file_url: str) -> str: """读取纯文本文件""" try: - with open(storage_path, 'r', encoding='utf-8') as f: - return f.read() + # 下载文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + return response.text except Exception as e: logger.error(f"读取文本文件失败: {e}") return f"[文件读取失败: {str(e)}]" - async def _extract_pdf_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_pdf_text(file_url: str) -> str: """提取 PDF 文本""" try: - # TODO: 实现 PDF 文本提取 - # import PyPDF2 或 pdfplumber - return "[PDF 文本提取功能待实现]" + # 下载 PDF 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + pdf_data = response.content + + # 使用 BytesIO 读取 PDF + text_parts = [] + pdf_file = io.BytesIO(pdf_data) + pdf_reader = PyPDF2.PdfReader(pdf_file) + for page in pdf_reader.pages: + text_parts.append(page.extract_text()) + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 PDF 文本失败: {e}") return f"[PDF 提取失败: {str(e)}]" - async def _extract_word_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_word_text(file_url: str) -> str: """提取 Word 文档文本""" try: - # TODO: 实现 Word 文本提取 - # import docx - return "[Word 文本提取功能待实现]" + # 下载 Word 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + word_data = response.content + + # 使用 BytesIO 读取 Word 文档 + word_file = io.BytesIO(word_data) + doc = Document(word_file) + text_parts = [paragraph.text for paragraph in doc.paragraphs] + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 Word 文本失败: {e}") return f"[Word 提取失败: {str(e)}]" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 99edcc0e..184220a8 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -184,7 +184,8 @@ class PromptOptimizerService: model_name=api_config.model_name, provider=api_config.provider, api_key=api_config.api_key, - base_url=api_config.api_base + base_url=api_config.api_base, + is_omni=api_config.is_omni ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 89d3f3d6..0d659832 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -247,6 +247,7 @@ class SharedChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -454,6 +455,7 @@ class SharedChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, From 0def474cc219bd3f2b84dbe0e9c6edca66d86b8c Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 10:30:35 +0800 Subject: [PATCH 091/164] feat(web): app's chat support audio/video/document file --- web/src/components/AudioRecorder/index.tsx | 21 +- web/src/components/Chat/ChatInput.tsx | 27 +- web/src/i18n/en.ts | 2 + web/src/i18n/zh.ts | 2 + web/src/utils/stream.ts | 14 +- .../ApplicationConfig/components/Chat.tsx | 260 ++++++++++-------- .../Conversation/components/FileUpload.tsx | 53 +++- .../components/UploadFileListModal.tsx | 6 +- web/src/views/Conversation/index.tsx | 41 ++- .../views/Workflow/components/Chat/Chat.tsx | 24 +- 10 files changed, 284 insertions(+), 166 deletions(-) diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index f6a030b4..d31746f6 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -1,16 +1,21 @@ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' -import { fileUpload } from '@/api/fileStorage' +import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' +import { request } from '@/utils/request' interface AudioRecorderProps { - onRecordingComplete?: (file: { file_id: string; file_key: string; }, blob: Blob) => void - className?: string + onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void + className?: string; + action?: string; + requestConfig?: Record; } const AudioRecorder: FC = ({ onRecordingComplete, className = '', + action = fileUploadUrlWithoutApiPrefix, + requestConfig = {} }) => { const [isRecording, setIsRecording] = useState(false) const recorderRef = useRef(null) @@ -33,11 +38,17 @@ const AudioRecorder: FC = ({ if (recorderRef.current) { recorderRef.current.stopRecording(() => { const blob = recorderRef.current!.getBlob() + const url = recorderRef.current!.toURL() const formData = new FormData() formData.append('file', blob, `recording_${Date.now()}.webm`) - fileUpload(formData) + request + .uploadFile(action, formData, requestConfig) .then(res => { - onRecordingComplete?.(res as { file_id: string; file_key: string; }, blob) + onRecordingComplete?.({ + ...(res as { file_id: string; file_key: string }), + type: blob.type, + url + }, blob) recorderRef.current?.destroy() recorderRef.current = null }) diff --git a/web/src/components/Chat/ChatInput.tsx b/web/src/components/Chat/ChatInput.tsx index c155bb22..49fb65d2 100644 --- a/web/src/components/Chat/ChatInput.tsx +++ b/web/src/components/Chat/ChatInput.tsx @@ -2,10 +2,11 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:14 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 12:13:52 + * @Last Modified time: 2026-03-04 18:42:49 */ import { type FC, useEffect, useMemo } from 'react' import { Flex, Input, Form } from 'antd' + import SendIcon from '@/assets/images/conversation/send.svg' import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg' import LoadingIcon from '@/assets/images/conversation/loading.svg' @@ -80,9 +81,31 @@ const ChatInput: FC = ({ ) } + if (file.type.includes('video')) { + return ( +
+
+ ) + } + if (file.type.includes('audio')) { + return ( +
+
+ ) + } return (
- {(file.type.includes('word') || file.type.includes('wordprocessingml.document')) &&
} {(file.type.includes('pdf')) &&
= ({ chatList, data, updateChatList, handleSave, sourc content: '', created_at: Date.now(), }; - + if (isCluster) { updateChatList(prev => prev.map(item => ({ ...item, @@ -134,7 +134,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update assistant message when error occurs */ - const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { + const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { if (message_length > 0 || !model_config_id) return updateChatList(prev => { @@ -217,6 +217,8 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc } } if (!isCanSend) { + setLoading(false) + setCompareLoading(false) return } runCompare(data.app_id, { @@ -243,7 +245,15 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc "stream": true, "timeout": 60, }, handleStreamMessage) - .finally(() => setLoading(false)); + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -288,7 +298,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update cluster message when error occurs */ - const updateClusterErrorAssistantMessage = (message_length: number) => { + const updateClusterErrorAssistantMessage = (message_length: number) => { if (message_length > 0) return updateChatList(prev => { @@ -331,7 +341,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc data.map(item => { const { conversation_id, content, message_length } = item.data as { conversation_id: string, content: string, message_length: number }; - switch(item.event) { + switch (item.event) { case 'start': if (conversation_id && conversationId !== conversation_id) { setConversationId(conversation_id); @@ -354,27 +364,35 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }; setTimeout(() => { - draftRun( - data.app_id, - { - message, - conversation_id: conversationId, - stream: true, - files: fileList.map(file => { - if (file.url) { - return file - } else { - return { - type: file.type, - transfer_method: 'local_file', - upload_file_id: file.response.data.file_id - } + draftRun( + data.app_id, + { + message, + conversation_id: conversationId, + stream: true, + files: fileList.map(file => { + if (file.url) { + return file + } else { + return { + type: file.type, + transfer_method: 'local_file', + upload_file_id: file.response.data.file_id } - }), - }, - handleStreamMessage - ) - .finally(() => setLoading(false)) + } + }), + }, + handleStreamMessage + ) + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -393,12 +411,17 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc const fileChange = (file?: any) => { setFileList([...fileList, file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + setFileList([...fileList, { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { - switch(key) { + switch (key) { case 'define': uploadFileListModalRef.current?.handleOpen() break @@ -415,99 +438,98 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return (
{chatList.length === 0 - ? - : <> -
- {chatList.map((chat, index) => ( -
1, - })}> - {chat.label && -
-
-
{chat.label}
-
handleDelete(index)} - >
+ : <> +
+ {chatList.map((chat, index) => ( +
1, + })}> + {chat.label && +
+
+
{chat.label}
+
handleDelete(index)} + >
+
-
- } - } - data={chat.list || []} - streamLoading={compareLoading} - labelPosition="top" - labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} - errorDesc={t('application.ReplyException')} - /> -
- ))} -
-
- - - - - ) - }, - ], - onClick: handleShowUpload + } + -
-
+ contentClassNames={{ + 'rb:max-w-[400px]!': chatList.length === 1, + 'rb:max-w-[260px]!': chatList.length === 2, + 'rb:max-w-[150px]!': chatList.length === 3, + 'rb:max-w-[108px]!': chatList.length === 4, + }} + empty={} + data={chat.list || []} + streamLoading={compareLoading} + labelPosition="top" + labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} + errorDesc={t('application.ReplyException')} + /> +
+ ))} +
+
+ + + + + ) + }, + ], + onClick: handleShowUpload + }} + > +
+
+
+ + + +
- {/* - - - */} - -
-
- + +
+ } { /** Custom file removal callback */ onRemove?: (file: UploadFile) => boolean | void | Promise; } + +const transform_file_type = { + 'text/plain': 'document/text', + 'text/markdown': 'document/markdown', + 'text/x-markdown': 'document/x-markdown', + + 'application/pdf': 'document/pdf', + + 'application/msword': 'document/doc', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx', + + 'application/vnd.ms-powerpoint': 'document/ppt', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx', +} // Mapping of file extensions to MIME types const ALL_FILE_TYPE: { [key: string]: string; } = { - // txt: 'text/plain', + txt: 'text/plain', + md: 'text/markdown', + xmd: 'text/x-markdown', + pdf: 'application/pdf', doc: 'application/msword', docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - - xls: 'application/vnd.ms-excel', - xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - csv: 'text/csv', ppt: 'application/vnd.ms-powerpoint', pptx: 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - - // md: 'text/markdown', - // htm: 'text/html', - // html: 'text/html', - // json: 'application/json', + jpg: 'image/jpeg', jpeg: 'image/jpeg', png: 'image/png', @@ -84,6 +94,23 @@ const ALL_FILE_TYPE: { bmp: 'image/bmp', webp: 'image/webp', svg: 'image/svg+xml', + + mp4: 'video/mp4', + mov: 'video/quicktime', + avi: 'video/x-msvideo', + mkv: 'video/x-matroska', + webm: 'video/webm', + flv: 'video/x-flv', + wmv: 'video/x-ms-wmv', + + mp3: 'audio/mpeg', + wav: 'audio/wav', + ogg: 'audio/ogg', + aac: 'audio/aac', + flac: 'audio/flac', + m4a: 'audio/mp4', + wma: 'audio/x-ms-wma', + xm4a: 'audio/x-m4a', } export interface UploadFilesRef { /** Current file list */ @@ -178,6 +205,10 @@ const UploadFiles = forwardRef(({ * Handles upload state changes */ const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { + newFileList.map(file => { + const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type + file.type = type + }) setFileList(newFileList); if (onChange) { onChange(maxCount === 1 ? newFileList[newFileList.length - 1] : newFileList); diff --git a/web/src/views/Conversation/components/UploadFileListModal.tsx b/web/src/views/Conversation/components/UploadFileListModal.tsx index c5110701..a43b9dd4 100644 --- a/web/src/views/Conversation/components/UploadFileListModal.tsx +++ b/web/src/views/Conversation/components/UploadFileListModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:47 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 10:17:54 + * @Last Modified time: 2026-03-04 17:47:09 */ /** * Upload File List Modal Component @@ -104,7 +104,9 @@ const UploadFileListModal = forwardRef diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index f532ac53..5509ad0a 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -14,7 +14,7 @@ import { type FC, useState, useEffect, useRef } from 'react' import { useParams, useLocation } from 'react-router-dom' import { useTranslation } from 'react-i18next' import InfiniteScroll from 'react-infinite-scroll-component'; -import { Flex, Skeleton, Form, Dropdown, type MenuProps, App } from 'antd' +import { Flex, Skeleton, Form, Dropdown, type MenuProps, App, Divider } from 'antd' import { SettingOutlined } from '@ant-design/icons' import clsx from 'clsx' import dayjs from 'dayjs' @@ -35,7 +35,7 @@ import OnlineCheckedIcon from '@/assets/images/conversation/onlineChecked.svg' import MemoryFunctionCheckedIcon from '@/assets/images/conversation/memoryFunctionChecked.svg' import { type SSEMessage } from '@/utils/stream' import UploadFiles from './components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from '@/components/AudioRecorder' import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import UploadFileListModal from './components/UploadFileListModal' import type { VariableConfigModalRef } from '@/views/Workflow/types' @@ -305,17 +305,27 @@ const Conversation: FC = () => { }), variables: params }, handleStreamMessage, shareToken) + .catch(() => { + setLoading(false) + setStreamLoading(false) + }) .finally(() => { setLoading(false) + setStreamLoading(false) }) } const fileChange = (file?: any) => { form.setFieldValue('files', [...(queryValues.files || []), file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + form.setFieldValue('files', [...(queryValues.files || []), { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { switch(key) { @@ -329,6 +339,7 @@ const Conversation: FC = () => { form.setFieldValue('files', [...(queryValues.files || []), ...fileList]) } const updateFileList = (fileList?: any[]) => { + console.log('fileList', fileList) form.setFieldValue('files', [...(fileList || [])]) } @@ -383,7 +394,7 @@ const Conversation: FC = () => {
} - contentClassName="rb:h-[calc(100%-180px)]" + contentClassName={!queryValues?.files?.length ? "rb:h-[calc(100%-144px)]" : "rb:h-[calc(100%-208px)]"} data={chatList} streamLoading={streamLoading} loading={loading} @@ -405,13 +416,12 @@ const Conversation: FC = () => { key: 'upload', label: ( ) }, @@ -455,10 +465,19 @@ const Conversation: FC = () => { )} - {/* - + + - */} + diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index 65989b30..f8049cb7 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:10:56 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-28 16:43:06 + * @Last Modified time: 2026-03-04 18:51:48 */ /** * Workflow Chat Component @@ -23,7 +23,7 @@ */ import { forwardRef, useImperativeHandle, useState, useRef } from 'react' import { useTranslation } from 'react-i18next' -import { App, Space, Button, Flex, Dropdown, type MenuProps } from 'antd' +import { App, Space, Button, Flex, Dropdown, type MenuProps, Divider } from 'antd' import ChatIcon from '@/assets/images/application/chat.png' import RbDrawer from '@/components/RbDrawer'; @@ -38,7 +38,7 @@ import { type SSEMessage } from '@/utils/stream' import type { Variable } from '../Properties/VariableList/types' import ChatInput from '@/components/Chat/ChatInput' import UploadFiles from '@/views/Conversation/components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from '@/components/AudioRecorder' import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal' import type { UploadFileListModalRef } from '@/views/Conversation/types' import Runtime from './Runtime'; @@ -359,6 +359,7 @@ const Chat = forwardRef(({ appId setStreamLoading(true) draftRun(appId, data, handleStreamMessage) .catch((error) => { + console.log('draftRun error', error) setChatList(prev => { const newList = [...prev] const lastIndex = newList.length - 1 @@ -390,9 +391,13 @@ const Chat = forwardRef(({ appId const fileChange = (file?: any) => { setFileList([...fileList, file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + setFileList([...fileList, { + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } /** * Handles dropdown menu actions for file upload @@ -424,6 +429,8 @@ const Chat = forwardRef(({ appId handleClose })); + console.log('fileList', fileList) + return ( @@ -470,7 +477,6 @@ const Chat = forwardRef(({ appId { key: 'upload', label: ( ) @@ -484,10 +490,10 @@ const Chat = forwardRef(({ appId >
- {/* + - */} +
From 2bd364eca30e6bf60f00a7473014b78870777e1a Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 5 Mar 2026 10:46:31 +0800 Subject: [PATCH 092/164] [add] migration script --- .../versions/b4af97639217_202603051033.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 api/migrations/versions/b4af97639217_202603051033.py diff --git a/api/migrations/versions/b4af97639217_202603051033.py b/api/migrations/versions/b4af97639217_202603051033.py new file mode 100644 index 00000000..ddeae41c --- /dev/null +++ b/api/migrations/versions/b4af97639217_202603051033.py @@ -0,0 +1,63 @@ +"""202603051033 + +Revision ID: b4af97639217 +Revises: 4bf27c66ae63 +Create Date: 2026-03-05 10:36:06.282227 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b4af97639217' +down_revision: Union[str, None] = '4bf27c66ae63' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Add columns as nullable first to avoid table locks + op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + # Update existing rows with default values + op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL") + + # Now make columns NOT NULL + op.alter_column('model_api_keys', 'capability', nullable=False) + op.alter_column('model_api_keys', 'is_omni', nullable=False) + + op.alter_column('model_bases', 'capability', nullable=False) + op.alter_column('model_bases', 'is_omni', nullable=False) + + op.alter_column('model_configs', 'capability', nullable=False) + op.alter_column('model_configs', 'is_omni', nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('model_configs', 'is_omni') + op.drop_column('model_configs', 'capability') + op.drop_column('model_bases', 'is_omni') + op.drop_column('model_bases', 'capability') + op.drop_column('model_api_keys', 'is_omni') + op.drop_column('model_api_keys', 'capability') + # ### end Alembic commands ### From 1b666638bc94f0a775fc9b6748854c4fdaedfafe Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 10:58:25 +0800 Subject: [PATCH 093/164] feat(web): add SYSTEM_DEFAULT_SCENE_CANNOT_DELETE error i18n --- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + web/src/utils/request.ts | 4 +++- web/src/views/Ontology/index.tsx | 12 +++++++----- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 02add0ec..7cef2d6c 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -440,6 +440,7 @@ export const en = { logoutApiCannotRefreshToken: 'Logout API cannot refresh token', publicApiCannotRefreshToken: 'Public API cannot refresh token', refreshTokenNotExist: 'Refresh token does not exist', + SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted', reset: 'Reset', refresh: 'Refresh', return: 'Return', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 06abf63a..5c688934 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1016,6 +1016,7 @@ export const zh = { logoutApiCannotRefreshToken: '退出登录接口不能刷新token', publicApiCannotRefreshToken: '公共接口不能刷新token', refreshTokenNotExist: '刷新token不存在', + SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除', reset: '重置', refresh: '刷新', return: '返回', diff --git a/web/src/utils/request.ts b/web/src/utils/request.ts index 3c3e8fa2..f58f5f65 100644 --- a/web/src/utils/request.ts +++ b/web/src/utils/request.ts @@ -183,7 +183,9 @@ service.interceptors.response.use( msg = msg || i18n.t('common.serverError'); break; default: - if (!msg && Array.isArray(error.response?.data?.detail)) { + if (msg === 'SYSTEM_DEFAULT_SCENE_CANNOT_DELETE') { + msg = i18n.t(`common.${msg}`) + } else if (!msg && Array.isArray(error.response?.data?.detail)) { msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';') } else { msg = msg || i18n.t('common.unknownError'); diff --git a/web/src/views/Ontology/index.tsx b/web/src/views/Ontology/index.tsx index 42a6544f..37f9118d 100644 --- a/web/src/views/Ontology/index.tsx +++ b/web/src/views/Ontology/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:10:15 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 14:10:15 + * @Last Modified time: 2026-03-05 10:57:53 */ import { type FC, useState, useRef, type MouseEvent } from 'react'; import { useNavigate } from 'react-router-dom'; @@ -164,11 +164,13 @@ const Ontology: FC = () => {
))} - +
{t('ontology.entityTypes')}:
- {item.entity_type?.map((type, i) => ( - {type} - ))} +
+ {item.entity_type?.map((type, i) => ( + {type} + ))} +
{item.type_num > 3 && ( +{item.type_num - 3} )} From b5ba53208e41fa3c265e859b92bb1ccb0401b102 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 11:05:51 +0800 Subject: [PATCH 094/164] feat(web): chat variable support paragraph --- web/src/views/Workflow/components/Chat/VariableConfigModal.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx b/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx index 5acd3eb1..66491ab7 100644 --- a/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx +++ b/web/src/views/Workflow/components/Chat/VariableConfigModal.tsx @@ -80,6 +80,7 @@ const VariableConfigModal = forwardRef } + { field.type === 'paragraph' && } { field.type === 'number' && form.setFieldValue(['variables', name, 'value'], value)} /> } From e511b149330ca42dd36dc1b8cbfc1ecc463651f2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 5 Mar 2026 11:06:46 +0800 Subject: [PATCH 095/164] [fix] Deleting the default scene results in a 400 status code. A unified language pop-up prompt is displayed. --- api/app/controllers/ontology_controller.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 42d4bee0..c892b013 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -523,10 +523,9 @@ async def delete_scene( f"尝试删除系统默认场景: user_id={current_user.id}, " f"scene_id={scene_id}, scene_name={scene.scene_name}" ) - return fail( - BizCode.BAD_REQUEST, - "系统默认场景不可删除", - "该场景为系统预设场景,不允许删除" + raise HTTPException( + status_code=400, + detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE" ) # 创建OntologyService实例 @@ -552,6 +551,9 @@ async def delete_scene( return success(data={"deleted": success_flag}, msg="场景删除成功") + except HTTPException: + raise + except ValueError as e: api_logger.warning(f"Validation error in scene deletion: {str(e)}") return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) From 16c1cbe24fa9fe5c1f21d2c2c6a53e5fb292d5d0 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 5 Mar 2026 11:17:56 +0800 Subject: [PATCH 096/164] feat(agent): add input variable validation --- api/app/controllers/app_controller.py | 16 +- api/app/core/models/base.py | 59 +-- api/app/core/workflow/nodes/agent/node.py | 6 +- api/app/schemas/api_key_schema.py | 9 +- api/app/schemas/multi_agent_schema.py | 6 +- api/app/services/agent_tools.py | 4 +- api/app/services/app_chat_service.py | 198 ++------- api/app/services/app_service.py | 413 ------------------ api/app/services/draft_run_service.py | 437 +++++++++---------- api/app/services/langchain_tool_server.py | 37 +- api/app/services/multi_agent_orchestrator.py | 17 +- api/app/services/skill_service.py | 2 +- api/app/services/tool_service.py | 8 +- 13 files changed, 330 insertions(+), 882 deletions(-) diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 653f616c..cdf94345 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -396,10 +396,10 @@ async def draft_run( from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService service = AppService(db) - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) # 1. 验证应用 app = service._get_app_or_404(app_id) @@ -484,8 +484,8 @@ async def draft_run( } ) - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_cfg, model_config=model_config, @@ -789,8 +789,8 @@ async def draft_run_compare( # 流式返回 if payload.stream: async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) async for event in draft_service.run_compare_stream( agent_config=agent_cfg, models=model_configs, @@ -820,8 +820,8 @@ async def draft_run_compare( ) # 非流式返回 - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run_compare( agent_config=agent_cfg, models=model_configs, diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 5d4dbd10..dba6717d 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -21,6 +21,7 @@ from pydantic import BaseModel, Field T = TypeVar("T") + class RedBearModelConfig(BaseModel): """模型配置基类""" model_name: str @@ -32,17 +33,18 @@ class RedBearModelConfig(BaseModel): timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2"))) - concurrency: int = 5 # 并发限流 + concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + class RedBearModelFactory: """模型工厂类""" - + @classmethod def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() - + # 打印供应商信息用于调试 from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -87,7 +89,7 @@ class RedBearModelFactory: "timeout": timeout_config, "max_retries": config.max_retries, **config.extra_params - } + } elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 @@ -104,7 +106,7 @@ class RedBearModelFactory: # region 从 base_url 或 extra_params 获取 from botocore.config import Config as BotoConfig from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id - + max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50")) max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2")) # Configure with increased connection pool @@ -112,16 +114,16 @@ class RedBearModelFactory: max_pool_connections=max_pool_connections, retries={'max_attempts': max_retries, 'mode': 'adaptive'} ) - + # 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID) model_id = normalize_bedrock_model_id(config.model_name) - + params = { "model_id": model_id, "config": boto_config, **config.extra_params } - + # 解析 API key (格式: access_key_id:secret_access_key) if config.api_key and ":" in config.api_key: access_key_id, secret_access_key = config.api_key.split(":", 1) @@ -129,51 +131,52 @@ class RedBearModelFactory: params["aws_secret_access_key"] = secret_access_key elif config.api_key: params["aws_access_key_id"] = config.api_key - + # 设置 region if config.base_url: params["region_name"] = config.base_url elif "region_name" not in params: params["region_name"] = "us-east-1" # 默认区域 - + return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - + @classmethod def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - return { + return { "model": config.model_name, # "base_url": config.base_url, "jina_api_key": config.api_key, **config.extra_params - } + } else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) -def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: + +def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]: """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - + # dashscope 的 omni 模型使用 OpenAI 兼容模式 if provider == ModelProvider.DASHSCOPE and config.is_omni: from langchain_openai import ChatOpenAI return ChatOpenAI - - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if type == ModelType.LLM: from langchain_openai import OpenAI - return OpenAI + return OpenAI elif type == ModelType.CHAT: from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: from langchain_community.chat_models import ChatTongyi return ChatTongyi - elif provider == ModelProvider.OLLAMA: + elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaLLM return OllamaLLM elif provider == ModelProvider.BEDROCK: @@ -183,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType. else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_embedding_class(provider: str) -> type[Embeddings]: """根据模型提供商获取对应的模型类""" provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings + return OpenAIEmbeddings elif provider == ModelProvider.DASHSCOPE: from langchain_community.embeddings import DashScopeEmbeddings - return DashScopeEmbeddings + return DashScopeEmbeddings elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings @@ -201,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]: else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_rerank_class(provider: str): """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + provider = provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_community.document_compressors import JinaRerank - return JinaRerank - # elif provider == ModelProvider.OLLAMA: + return JinaRerank + # elif provider == ModelProvider.OLLAMA: # from langchain_ollama import OllamaEmbeddings # return OllamaEmbeddings else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) \ No newline at end of file + raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 98d8bb75..3fbbbdbc 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.variable.base_variable import VariableType from app.db import get_db from app.models import AppRelease -from app.services.draft_run_service import DraftRunService +from app.services.draft_run_service import AgentRunService logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ class AgentNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} - def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: """准备 Agent(公共逻辑) Args: @@ -65,7 +65,7 @@ class AgentNode(BaseNode): if not release: raise ValueError(f"Agent 不存在: {agent_id}") - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) return draft_service, release, message diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index d19cf061..323c1a69 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -155,8 +155,7 @@ class ApiKey(BaseModel): return datetime.datetime.now() > self.expires_at @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel): avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") @field_serializer('last_used_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel): created_at: datetime.datetime @field_serializer('created_at') - @classmethod - def serialize_datetime(cls, v: datetime.datetime) -> int: + def serialize_datetime(self, v: datetime.datetime) -> int: """将datetime转换为时间戳""" return datetime_to_timestamp(v) diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index 8fba2929..3573e87c 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel): class MultiAgentConfigCreate(BaseModel): """创建多 Agent 配置""" master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") orchestration_mode: str = Field( default="collaboration", pattern="^(collaboration|supervisor)$", description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") - routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") + routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") aggregation_strategy: str = Field( default="merge", @@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel): class MultiAgentConfigUpdate(BaseModel): """更新多 Agent 配置""" master_agent_id: Optional[uuid.UUID] = None - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field( None, diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 3ca7bddd..a4768b51 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -263,8 +263,8 @@ def create_agent_invocation_tool( try: # 9. 调用 Agent - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_config, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index e6ac227b..5430d2f9 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,25 +10,24 @@ from sqlalchemy.orm import Session from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.db import get_db, get_db_context -from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig -from app.schemas import DraftRunRequest -from app.schemas.app_schema import FileInput -from app.services.tool_service import ToolService -from app.repositories.tool_repository import ToolRepository from app.db import get_db from app.models import MultiAgentConfig, AgentConfig +from app.models import WorkflowConfig +from app.repositories.tool_repository import ToolRepository +from app.schemas import DraftRunRequest +from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool +from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ + AgentRunService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator -from app.services.workflow_service import WorkflowService from app.services.multimodal_service import MultimodalService +from app.services.tool_service import ToolService +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -39,6 +38,8 @@ class AppChatService: def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) + self.agent_service = AgentRunService(db) + self.workflow_service = WorkflowService(db) async def agnet_chat( self, @@ -55,12 +56,10 @@ class AppChatService: files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> Dict[str, Any]: """聊天(非流式)""" - start_time = time.time() config_id = None - if variables is None: - variables = {} + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id @@ -79,74 +78,20 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - # 从配置中获取启用的工具 - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) memory_flag = False - if memory == True: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + if memory: + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -246,10 +191,9 @@ class AppChatService: try: start_time = time.time() config_id = None + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - if variables is None: - variables = {} - + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) @@ -267,73 +211,22 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -372,9 +265,6 @@ class AppChatService: processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - # 流式调用 Agent(支持多模态) full_content = "" total_tokens = 0 @@ -418,7 +308,7 @@ class AppChatService: ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} + end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" logger.info( @@ -437,7 +327,7 @@ class AppChatService: except Exception as e: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" async def multi_agent_chat( self, @@ -491,10 +381,10 @@ class AppChatService: "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }) + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -524,8 +414,6 @@ class AppChatService: """多 Agent 聊天(流式)""" start_time = time.time() - actual_config_id = None - config_id = actual_config_id if variables is None: variables = {} @@ -631,7 +519,6 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -639,7 +526,7 @@ class AppChatService: stream=True, user_id=user_id ) - return await workflow_service.run( + return await self.workflow_service.run( app_id=app_id, payload=payload, config=config, @@ -666,7 +553,6 @@ class AppChatService: ) -> AsyncGenerator[dict, None]: """聊天(流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -675,7 +561,7 @@ class AppChatService: user_id=user_id, files=files ) - async for event in workflow_service.run_stream( + async for event in self.workflow_service.run_stream( app_id=app_id, payload=payload, config=config, diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index c5919af9..a248f869 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1791,372 +1791,6 @@ class AppService: return shares - # ==================== 试运行功能 ==================== - - async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """试运行 Agent(使用当前草稿配置) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Dict: 包含 AI 回复和元数据的字典 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用试运行服务 - logger.debug( - "准备调用试运行服务", - extra={ - "app_id": str(app_id), - "model": model_config.name, - "has_conversation_id": bool(conversation_id), - "has_variables": bool(variables) - } - ) - - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ) - - logger.debug( - "试运行服务返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict", - "has_message": "message" in result if isinstance(result, dict) else False, - "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False - } - ) - - logger.info( - "试运行完成", - extra={ - "app_id": str(app_id), - "elapsed_time": result.get("elapsed_time"), - "model": model_config.name - } - ) - - return result - - async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ): - """试运行 Agent(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Yields: - str: SSE 格式的事件数据 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用流式试运行服务 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ): - yield event - - # ==================== 多模型对比试运行 ==================== - - async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models), - "parallel": parallel - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的对比方法 - draft_service = DraftRunService(self.db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ) - - logger.info( - "多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return result - - async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ): - """多模型对比试运行(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比流式试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的流式对比方法 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ): - yield event - - logger.info( - "多模型对比流式完成", - extra={"app_id": str(app_id)} - ) - - # ==================== 向后兼容的函数接口 ==================== # 保留函数接口以兼容现有代码,但内部使用服务类 @@ -2278,53 +1912,6 @@ def get_apps_by_ids( return service.get_apps_by_ids(app_ids, workspace_id) -# ==================== 向后兼容的函数接口 ==================== - -async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -) -> Dict[str, Any]: - """试运行 Agent(向后兼容接口)""" - service = AppService(db) - return await service.draft_run( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ) - - -async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -): - """试运行 Agent 流式返回(向后兼容接口)""" - service = AppService(db) - async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ): - yield event - - # ==================== 依赖注入函数 ==================== def get_app_service( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 693f1a26..0cf68be2 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -17,6 +17,7 @@ from sqlalchemy.orm import Session from app.celery_app import celery_app from app.core.agent.agent_middleware import AgentMiddleware +from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger @@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service +from app.services.conversation_service import ConversationService from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger @@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel): description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") -def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None): +def create_long_term_memory_tool( + memory_config: Dict[str, Any], + end_user_id: str, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None +): """创建记忆工具, @@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str memory_config: 记忆配置 end_user_id: 用户ID storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) Returns: 长期记忆工具 @@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: - query: 需要检索的问题或关键词 + kb_config: 知识库配置 + kb_ids: 知识库ID列表 + user_id: 用户ID Returns: 检索到的相关知识内容 @@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): return knowledge_retrieval_tool -class DraftRunService: - """试运行服务类""" +class AgentRunService: + """Agent运行服务类""" def __init__(self, db: Session): - """初始化试运行服务 + """Agent运行服务 Args: db: 数据库会话 """ self.db = db + @staticmethod + def prepare_variables( + input_vars: dict | None, + variables_config: dict | None + ) -> dict: + input_vars = input_vars or {} + for variable in variables_config: + if variable.get("required") and variable.get("name") not in input_vars: + raise ValueError(f"The required parameter '{variable.get('name')}' was not provided") + return input_vars + + def load_tools_config(self, tools_config, web_search, tenant_id) -> list: + """加载工具配置""" + if not tools_config: + return [] + tools = [] + tool_service = ToolService(self.db) + + if tools_config and isinstance(tools_config, list): + for tool_config in tools_config: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif tools_config and isinstance(tools_config, dict): + web_search_choice = tools_config.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search and web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + return tools + + def load_skill_config( + self, + skills_config: dict | None, + message: str, tenant_id + ) -> tuple[list, str]: + if not skills_config: + return [], "" + + tools = [] + skill_prompts = "" + skill_enable = skills_config.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills_config) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + skill_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + + return tools, skill_prompts + + def load_knowledge_retrieval_config( + self, + knowledge_retrieval_config: dict | None, + user_id + ) -> list: + if not knowledge_retrieval_config: + return [] + + tools = [] + knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) + kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + if kb_ids: + # 创建知识库检索工具 + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + tools.append(kb_tool) + + logger.debug( + "已添加知识库检索工具", + extra={ + "kb_ids": kb_ids, + "tool_count": len(tools) + } + ) + return tools + + def load_memory_config( + self, + memory_config: dict | None, + user_id, + storage_type, + user_rag_memory_id + ) -> tuple[list, bool]: + """加载长期记忆配置""" + if not memory_config: + return [], False + + tools = [] + if memory_config.get("enabled"): + if user_id: + # 创建长期记忆工具 + memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.append(memory_tool) + + logger.debug( + "已添加长期记忆工具", + extra={ + "user_id": user_id, + "tool_count": len(tools) + } + ) + return tools, bool(memory_config.get("enabled")) + async def run( self, *, @@ -270,19 +403,21 @@ class DraftRunService: conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 + storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) + web_search: 是否启用网络搜索(默认True) + memory: 是否启用长期记忆(默认True) + sub_agent: 是否为子代理调用(默认False) + files: 多模态文件列表(可选) Returns: Dict: 包含 AI 回复和元数据的字典 """ - memory_flag = False - - print('===========', storage_type) - - print(user_id) - if variables == None: variables = {} - from app.core.agent.langchain_agent import LangChainAgent - start_time = time.time() + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory try: # 1. 获取 API Key 配置 @@ -302,112 +437,40 @@ class DraftRunService: agent_config=agent_config ) - items_params = variables + if sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} + system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 + agent_config.system_prompt, PromptMessageRole.USER, - items_params + variables ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - print('系统提示词:', system_prompt) # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_config in agent_config.tools: - print("+" * 50) - print(f"agent_config:{agent_config}") - print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config( + memory_config, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -432,7 +495,7 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, max_history=agent_config.memory.get("max_history", 10) @@ -482,7 +545,7 @@ class DraftRunService: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -557,16 +620,21 @@ class DraftRunService: Yields: str: SSE 格式的事件数据 """ - memory_flag = False - if variables == None: variables = {} - - from app.core.agent.langchain_agent import LangChainAgent + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory start_time = time.time() try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) + if not sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( @@ -588,95 +656,22 @@ class DraftRunService: # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - for tool_config in agent_config.tools: - # print("+"*50) - # print(f"agent_config:{agent_config}") - # print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -702,10 +697,10 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) + max_history=memory_config.get("max_history", 10) ) # 6. 处理多模态文件 @@ -763,7 +758,7 @@ class DraftRunService: }) # 10. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -969,7 +964,6 @@ class DraftRunService: List[Dict]: 历史消息列表 """ try: - from app.services.conversation_service import ConversationService conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( @@ -1489,6 +1483,15 @@ class DraftRunService: "conversation_id": returned_conversation_id, "content": chunk })) + + if event_type == "error" and event_data: + await event_queue.put(self._format_sse_event("model_error", { + "model_index": idx, + "model_config_id": model_config_id, + "label": model_label, + "conversation_id": returned_conversation_id, + "error": event_data.get("error", "未知错误") + })) except Exception as e: logger.warning(f"解析流式事件失败: {e}") finally: @@ -1673,41 +1676,3 @@ class DraftRunService: "total_time": sum(r.get("elapsed_time", 0) for r in results) } ) - - -async def draft_run( - db: Session, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - user_id: Optional[str] = None, - kb_ids: Optional[List[str]] = None, - similarity_threshold: float = 0.7, - top_k: int = 3 -) -> Dict[str, Any]: - """试运行 Agent(便捷函数) - - Args: - db: 数据库会话 - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - user_id: 用户ID - kb_ids: 知识库ID列表 - similarity_threshold: 相似度阈值 - top_k: 检索返回的文档数量 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - service = DraftRunService(db) - return await service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - user_id=user_id, - kb_ids=kb_ids, - similarity_threshold=similarity_threshold, - top_k=top_k - ) diff --git a/api/app/services/langchain_tool_server.py b/api/app/services/langchain_tool_server.py index f44e4cdc..2c151956 100644 --- a/api/app/services/langchain_tool_server.py +++ b/api/app/services/langchain_tool_server.py @@ -9,6 +9,8 @@ load_dotenv() # 读取web_search环境变量 web_search_value = os.getenv('web_search') + + def Search(query): url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" api_key = web_search_value @@ -18,23 +20,24 @@ def Search(query): "role": "user", "content": query } - ], #搜索输入 - "edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 - "search_source": "baidu_search_v2", #使用的搜索引擎版本 - "resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 + ], # 搜索输入 + "edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 + "search_source": "baidu_search_v2", # 使用的搜索引擎版本 + "resource_type_filter": [{"type": "web", "top_k": 20}], + # 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 "search_filter": { "range": { "page_time": { - "gte": "now-1w/d", #时间查询参数,大于或等于 - "lt": "now/d", #时间查询参数,小于 - "gt": "", #时间查询参数,大于 - "lte": "" #时间查询参数,小于或等于 + "gte": "now-1w/d", # 时间查询参数,大于或等于 + "lt": "now/d", # 时间查询参数,小于 + "gt": "", # 时间查询参数,大于 + "lte": "" # 时间查询参数,小于或等于 } } }, - "block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 - "search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year - "enable_full_content":True #是否输出网页完整原文 + "block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表 + "search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year + "enable_full_content": True # 是否输出网页完整原文 }, ensure_ascii=False) headers = { 'Content-Type': 'application/json', @@ -42,10 +45,10 @@ def Search(query): } response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() - content=[] + content = [] for i in response['references']: - title=i['title'] - snippet=i['snippet'] - content.append(title+';'+snippet) - content='。'.join(content) - return content \ No newline at end of file + title = i['title'] + snippet = i['snippet'] + content.append(title + ';' + snippet) + content = '。'.join(content) + return content diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 650f639b..f42ee95a 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -123,11 +123,14 @@ class MultiAgentOrchestrator: user_id: 用户 ID variables: 变量参数 use_llm_routing: 是否使用 LLM 路由 + web_search: 是否启用网络搜索 + memory: 是否启用记忆功能 + storage_type: 存储类型 + user_rag_memory_id: 用户 RAG 记忆 ID Yields: SSE 格式的事件流 """ - import json start_time = time.time() @@ -200,7 +203,8 @@ class MultiAgentOrchestrator: except Exception as e: logger.error( "多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self._normalized_mode} + extra={"error": str(e), "mode": self._normalized_mode}, + exc_info=True ) # 发送错误事件 yield self._format_sse_event("error", { @@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator: Yields: SSE 格式的事件流 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator: ) # 流式执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) async for event in draft_service.run_stream( agent_config=agent_config, model_config=model_config, @@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator: Returns: 执行结果 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator: ) # 执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) result = await draft_service.run( agent_config=agent_config, model_config=model_config, @@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator: self.memory = config_data.get("memory") self.variables = config_data.get("variables", []) self.tools = config_data.get("tools", {}) + self.skills = config_data.get("skills", {}) self.default_model_config_id = release.default_model_config_id return AgentConfigProxy(release, app, config_data) diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py index 5eb80795..0b7de6cf 100644 --- a/api/app/services/skill_service.py +++ b/api/app/services/skill_service.py @@ -121,7 +121,7 @@ class SkillService: if skill and skill.is_active: # 加载技能关联的工具 for tool_config in skill.tools: - tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool: langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 2bb96e53..d2400ded 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -209,7 +209,7 @@ class ToolService: try: # 获取工具实例 - tool = self._get_tool_instance(tool_id, tenant_id) + tool = self.get_tool_instance(tool_id, tenant_id) if not tool: return ToolResult.error_result( error=f"工具不存在: {tool_id}", @@ -335,7 +335,7 @@ class ToolService: return [] # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return [] @@ -792,7 +792,7 @@ class ToolService: """获取工具配置""" return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) - def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: + def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: """获取工具实例""" if tool_id in self._tool_cache: return self._tool_cache[tool_id] @@ -1416,7 +1416,7 @@ class ToolService: """测试内置工具连接""" try: # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return {"success": False, "message": "无法创建工具实例"} From a72d5d2c7768e090793172b7165e4a9f5ab2430e Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 5 Mar 2026 11:18:48 +0800 Subject: [PATCH 097/164] fix(workflow): add backward compatibility for old dify configs --- api/app/core/workflow/adapters/dify/converter.py | 6 ++++-- api/app/core/workflow/adapters/dify/dify_adapter.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 2014b4c3..06c988d3 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -98,7 +98,7 @@ class DifyConverter(BaseConverter): if not var_selector: return "" selector = var_selector.split('.') - if len(selector) not in [2, 3]: + if len(selector) not in [2, 3] and var_selector != "context": raise Exception(f"invalid variable selector: {var_selector}") if len(selector) == 3: selector = selector[1:] @@ -332,7 +332,9 @@ class DifyConverter(BaseConverter): messages.append( MessageConfig( role="user", - content=self.trans_variable_format(node_data["memory"]["query_prompt_template"]) + content=self.trans_variable_format( + node_data["memory"].get("query_prompt_template", "{{#sys.query#}}") + ) ) ) vision = node_data["vision"]["enabled"] diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index dcd14c7f..6336b1f9 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): return True def validate_config(self) -> bool: - require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'}) + require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) if not all(field in self.config for field in require_fields): return False From 78ce2a9a8b557082c454fba875165db39ae05749 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 5 Mar 2026 14:07:27 +0800 Subject: [PATCH 098/164] feat(workflow): support multimodal input --- api/app/core/workflow/engine/variable_pool.py | 34 ++++++++++----- api/app/core/workflow/nodes/base_node.py | 41 ++++++++++++++----- .../workflow/nodes/cycle_graph/iteration.py | 4 +- .../core/workflow/nodes/cycle_graph/loop.py | 2 +- api/app/core/workflow/nodes/llm/node.py | 6 +-- .../core/workflow/variable/base_variable.py | 12 ++++-- .../workflow/variable/variable_objects.py | 7 +++- api/app/services/draft_run_service.py | 2 +- api/app/services/workflow_service.py | 14 ++++--- 9 files changed, 84 insertions(+), 38 deletions(-) diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d08f47e5..bc88df19 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -303,38 +303,52 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None - def get_all_system_vars(self) -> dict[str, Any]: + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 Returns: 系统变量字典 """ sys_namespace = self.variables.get("sys", {}) + if literal: + return {k: v.instance.to_literal() for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()} - def get_all_conversation_vars(self) -> dict[str, Any]: + def get_all_conversation_vars(self, literal=False) -> dict[str, Any]: """获取所有会话变量 Returns: 会话变量字典 """ conv_namespace = self.variables.get("conv", {}) + if literal: + return {k: v.instance.to_literal() for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()} - def get_all_node_outputs(self) -> dict[str, Any]: + def get_all_node_outputs(self, literal=False) -> dict[str, Any]: """获取所有节点输出(运行时变量) Returns: 节点输出字典,键为节点 ID """ - runtime_vars = { - namespace: { - k: v.instance.get_value() - for k, v in vars_dict.items() + if literal: + runtime_vars = { + namespace: { + k: v.instance.to_literal() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + else: + runtime_vars = { + namespace: { + k: v.instance.get_value() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") } - for namespace, vars_dict in self.variables.items() - if namespace not in ("sys", "conv") - } return runtime_vars def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 3e30c00e..3f30718c 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from abc import ABC, abstractmethod from functools import cached_property from typing import Any, AsyncGenerator @@ -10,8 +11,10 @@ from app.core.config import settings from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES -from app.core.workflow.variable.base_variable import VariableType -from app.services.multimodal_service import PROVIDER_STRATEGIES +from app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.schemas import FileInput +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -548,9 +551,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(literal=True), + node_outputs=variable_pool.get_all_node_outputs(literal=True), + system_vars=variable_pool.get_all_system_vars(literal=True), strict=strict ) @@ -614,16 +617,32 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider, content, enable_file=False) -> dict | str | None: + async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None: if isinstance(content, str): if enable_file: return {"text": content} return content - elif isinstance(content, dict): - trans_tool = PROVIDER_STRATEGIES[provider]() - result = await trans_tool.format_image(content["url"]) - return result - raise TypeError('Unexpect input value type') + + elif isinstance(content, FileObject): + if content.content_cache.get(provider): + return content.content_cache[provider] + with get_db_read() as db: + multimodel_service = MultimodalService(db, provider) + message = await multimodel_service.process_files( + [FileInput.model_construct( + type=content.type, + url=content.url, + transfer_method=content.transfer_method, + file_type=content.origin_file_type, + upload_file_id=content.file_id + )] + ) + + if message: + content.content_cache[provider] = message[0] + return message[0] + return None + raise TypeError(f'Unexpect input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index e4026f2d..cf7ac976 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -91,8 +91,8 @@ class IterationRuntime: return loopstate def merge_conv_vars(self): - self.variable_pool.get_all_conversation_vars().update( - self.child_variable_pool.get_all_conversation_vars() + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables["conv"] ) async def run_task(self, item, idx): diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index cebadfdc..d3ada1ec 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -156,7 +156,7 @@ class LoopRuntime: def merge_conv_vars(self, loopstate): self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables.get("conv", {}) + self.child_variable_pool.variables["conv"] ) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index fdd5df58..c109d59b 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -172,9 +172,9 @@ class LLMNode(BaseNode): if self.typed_config.vision_input and self.typed_config.vision: file_content = [] - files = variable_pool.get_value(self.typed_config.vision_input) - for file in files: - content = await self.process_message(provider, file, self.typed_config.vision) + files = variable_pool.get_instance(self.typed_config.vision_input) + for file in files.value: + content = await self.process_message(provider, file.value, self.typed_config.vision) if content: file_content.append(content) if messages and messages[-1]["role"] == 'user': diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index 19cbdc74..dd821ea7 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.schemas import FileType @@ -45,7 +45,7 @@ class VariableType(StrEnum): return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN - elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): + elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): return cls.OBJECT @@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: class FileObject(BaseModel): type: FileType url: str - __file: bool + transfer_method: str + origin_file_type: str + file_id: str | None + + content_cache: dict = Field(default_factory=dict) + + is_file: bool class BaseVariable(ABC): diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 49541afc..63437fd9 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -63,13 +63,16 @@ class FileVariable(BaseVariable): def valid_value(self, value) -> FileObject: if isinstance(value, dict): - if not value.get("__file"): + if not value.get("is_file"): raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") return FileObject( **{ "type": str(value.get('type')), + "transfer_method": value.get("transfer_method"), "url": value.get('url'), - "__file": True + "file_id": value.get("file_id"), + "origin_file_type": value.get("origin_file_type"), + "is_file": True } ) if isinstance(value, FileObject): diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 0cf68be2..bb68c815 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -255,7 +255,7 @@ class AgentRunService: @staticmethod def prepare_variables( input_vars: dict | None, - variables_config: dict | None + variables_config: dict ) -> dict: input_vars = input_vars or {} for variable in variables_config: diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 02819efb..d13e3454 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config +from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution @@ -453,11 +454,14 @@ class WorkflowService: files_struct = [] for file in files: files_struct.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } + FileObject( + type=file.type, + url=await self.multimodal_service.get_file_url(file), + transfer_method=file.transfer_method, + file_id=str(file.upload_file_id), + origin_file_type=file.file_type, + is_file=True + ).model_dump() ) return files_struct From b5199b2eb91e90df116bdd430b9ffd9a971f769c Mon Sep 17 00:00:00 2001 From: lixiangcheng1 Date: Thu, 5 Mar 2026 14:18:33 +0800 Subject: [PATCH 099/164] =?UTF-8?q?=E3=80=90ADD=E3=80=91list=20operational?= =?UTF-8?q?=20mcp=20servers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mcp_market_config_controller.py | 61 ++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 98012568..7f73663e 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -90,7 +90,7 @@ async def get_mcp_servers( cookies=cookies) raise_for_http_status(r) except requests.exceptions.RequestException as e: - api_logger.error(f"mFailed to get MCP servers: {str(e)}") + api_logger.error(f"Failed to get MCP servers: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get MCP servers: {str(e)}" @@ -118,6 +118,65 @@ async def get_mcp_servers( return success(data=result, msg="Query of mcp servers list successful") +@router.get("/operational_mcp_servers", response_model=ApiResponse) +async def get_operational_mcp_servers( + mcp_market_config_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Query the operational mcp servers list in pages + - Support keyword search for name,author,owner + - Return paging metadata + operational mcp server list + """ + api_logger.info( + f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}") + + # 1. Query mcp market config information from the database + api_logger.debug(f"Query mcp market config: {mcp_market_config_id}") + db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, + mcp_market_config_id=mcp_market_config_id, + current_user=current_user) + if not db_mcp_market_config: + api_logger.warning( + f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The mcp market config does not exist or access is denied" + ) + + # 2. Execute paged query + api = MCPApi() + token = db_mcp_market_config.token + api.login(token) + + url = f'{api.mcp_base_url}/operational' + headers = api.builder_headers(api.headers) + + try: + cookies = api.get_cookies(access_token=token, cookies_required=True) + r = api.session.get(url, headers=headers, cookies=cookies) + raise_for_http_status(r) + except requests.exceptions.RequestException as e: + api_logger.error(f"Failed to get operational MCP servers: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get operational MCP servers: {str(e)}" + ) + + data = api._handle_response(r) + total = data.get('total_count', 0) + mcp_server_list = data.get('mcp_server_list', []) + # items = [{ + # 'name': item.get('name', ''), + # 'id': item.get('id', ''), + # 'description': item.get('description', '') + # } for item in mcp_server_list] + + # 3. Return structured response + return success(data=mcp_server_list, msg="Query of operational mcp servers list successful") + + @router.get("/mcp_server", response_model=ApiResponse) async def get_mcp_server( mcp_market_config_id: uuid.UUID, From 218637e81d29edb83c6c8d7204b2c291e6174c02 Mon Sep 17 00:00:00 2001 From: Mark Date: Thu, 5 Mar 2026 14:42:42 +0800 Subject: [PATCH 100/164] [add] migration script --- .../versions/6a4641cf192b_202603051440.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 api/migrations/versions/6a4641cf192b_202603051440.py diff --git a/api/migrations/versions/6a4641cf192b_202603051440.py b/api/migrations/versions/6a4641cf192b_202603051440.py new file mode 100644 index 00000000..0322c9e2 --- /dev/null +++ b/api/migrations/versions/6a4641cf192b_202603051440.py @@ -0,0 +1,43 @@ +"""202603051440 + +Revision ID: 6a4641cf192b +Revises: b4af97639217 +Create Date: 2026-03-05 14:41:03.371557 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '6a4641cf192b' +down_revision: Union[str, None] = 'b4af97639217' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('implicit_emotions_storage', + sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'), + sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'), + sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('end_user_id') + ) + op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('idx_updated_at', table_name='implicit_emotions_storage') + op.drop_table('implicit_emotions_storage') + # ### end Alembic commands ### From 60a95f655661e5cd0d22464ed63d943054c00759 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 5 Mar 2026 15:02:01 +0800 Subject: [PATCH 101/164] [changes] --- api/app/cache/__init__.py | 4 +--- api/app/cache/memory/__init__.py | 4 ---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index 5300348c..ca6a8784 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -4,10 +4,8 @@ Cache 缓存模块 提供各种缓存功能的统一入口 注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ -from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache +from .memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 46ad0b73..7bc86068 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -4,12 +4,8 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ -from .emotion_memory import EmotionMemoryCache -from .implicit_memory import ImplicitMemoryCache from .interest_memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] From 2e1eb9a5a67f2a041a3bb0b8efa02f0fad3af28d Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 15:12:18 +0800 Subject: [PATCH 102/164] feat(web): file type add default value --- web/src/views/Conversation/components/FileUpload.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/src/views/Conversation/components/FileUpload.tsx b/web/src/views/Conversation/components/FileUpload.tsx index 98ece0e3..9da64cc7 100644 --- a/web/src/views/Conversation/components/FileUpload.tsx +++ b/web/src/views/Conversation/components/FileUpload.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:42 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-04 18:54:47 + * @Last Modified time: 2026-03-05 15:09:22 */ /** * File Upload Component @@ -206,7 +206,7 @@ const UploadFiles = forwardRef(({ */ const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { newFileList.map(file => { - const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type + const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document' file.type = type }) setFileList(newFileList); From 9c9fe9dde70873ac4d1c533f04ff6a0b72271de4 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 5 Mar 2026 16:21:27 +0800 Subject: [PATCH 103/164] [fix] Remove the unused ones --- api/app/cache/__init__.py | 5 +---- api/app/cache/memory/__init__.py | 5 ----- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index 5300348c..ca7aa91a 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -2,12 +2,9 @@ Cache 缓存模块 提供各种缓存功能的统一入口 -注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ -from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache +from .memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 46ad0b73..9a7fd225 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -2,14 +2,9 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 -注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ -from .emotion_memory import EmotionMemoryCache -from .implicit_memory import ImplicitMemoryCache from .interest_memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] From 621b074b3d6d2ee32e78cd9aaeb1521923103415 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 16:36:39 +0800 Subject: [PATCH 104/164] feat(web): memory config & ontology add default tag --- web/src/i18n/en.ts | 1 + web/src/i18n/zh.ts | 1 + web/src/views/MemoryManagement/index.tsx | 12 +++++++++--- web/src/views/MemoryManagement/types.ts | 3 ++- web/src/views/Ontology/index.tsx | 11 ++++++++--- web/src/views/Ontology/types.ts | 6 ++++-- 6 files changed, 25 insertions(+), 9 deletions(-) diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 7cef2d6c..e0b144a9 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -454,6 +454,7 @@ export const en = { prevStep: 'Previous Step', exportSuccess: 'Export successful', recommend: 'Recommend', + default: 'Default', }, model: { searchPlaceholder: 'search model…', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 5c688934..5306f711 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1030,6 +1030,7 @@ export const zh = { prevStep: '上一步', exportSuccess: '导出成功', recommend: '推荐', + default: '默认', }, model: { searchPlaceholder: '搜索模型…', diff --git a/web/src/views/MemoryManagement/index.tsx b/web/src/views/MemoryManagement/index.tsx index ac2b4fa5..6ebb49c7 100644 --- a/web/src/views/MemoryManagement/index.tsx +++ b/web/src/views/MemoryManagement/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 17:33:15 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 17:33:15 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-05 16:28:58 */ /** * Memory Management Page @@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => { + {item.is_system_default && +
+ {t('common.default')} +
+ } -
{item.config_desc}
+
{item.config_desc}
diff --git a/web/src/views/MemoryManagement/types.ts b/web/src/views/MemoryManagement/types.ts index 48bdbb77..dc3ae091 100644 --- a/web/src/views/MemoryManagement/types.ts +++ b/web/src/views/MemoryManagement/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:33:01 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 17:33:24 + * @Last Modified time: 2026-03-05 16:33:53 */ /** * Memory management form data type @@ -42,6 +42,7 @@ export interface Memory { workspace_id: string; scene_id: string; scene_name: string; + is_system_default: boolean; [key: string]: string | number | boolean; } /** diff --git a/web/src/views/Ontology/index.tsx b/web/src/views/Ontology/index.tsx index 37f9118d..eaf1188b 100644 --- a/web/src/views/Ontology/index.tsx +++ b/web/src/views/Ontology/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 14:10:15 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-05 10:57:53 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-05 16:28:53 */ import { type FC, useState, useRef, type MouseEvent } from 'react'; import { useNavigate } from 'react-router-dom'; @@ -144,8 +144,13 @@ const Ontology: FC = () => { title={item.scene_name} extra={{item.type_num} {t('ontology.typeCount')}} onClick={() => handleJump(item)} - className="rb:cursor-pointer" + className="rb:cursor-pointer rb:relative" > + {item.is_system_default && +
+ {t('common.default')} +
+ }
diff --git a/web/src/views/Ontology/types.ts b/web/src/views/Ontology/types.ts index d78d8464..aad94ee0 100644 --- a/web/src/views/Ontology/types.ts +++ b/web/src/views/Ontology/types.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 14:10:10 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 14:10:10 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-05 16:18:56 */ /** * Query parameters for ontology list pagination and filtering @@ -38,6 +38,8 @@ export interface OntologyItem { updated_at: number; /** Total count of classes in the scene */ classes_count: number; + /** Whether this is the system default configuration */ + is_system_default: boolean; } /** From 495c5802a0496a2177a30f31dbfa691bc193ef6a Mon Sep 17 00:00:00 2001 From: zhaoying Date: Thu, 5 Mar 2026 16:43:59 +0800 Subject: [PATCH 105/164] feat(web): knowledge add form rules --- web/src/views/KnowledgeBase/components/CreateModal.tsx | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/web/src/views/KnowledgeBase/components/CreateModal.tsx b/web/src/views/KnowledgeBase/components/CreateModal.tsx index 76640058..d9727d18 100644 --- a/web/src/views/KnowledgeBase/components/CreateModal.tsx +++ b/web/src/views/KnowledgeBase/components/CreateModal.tsx @@ -15,6 +15,7 @@ import { } from '@/api/knowledgeBase' import RbModal from '@/components/RbModal' import SliderInput from '@/components/SliderInput' +import { stringRegExp } from '@/utils/validator' const { TextArea } = Input; const { confirm } = Modal @@ -519,12 +520,16 @@ const CreateModal = forwardRef(({ )} - +