diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 9a7fd225..551062ac 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -4,7 +4,9 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 """ from .interest_memory import InterestMemoryCache +from .activity_stats_cache import ActivityStatsCache __all__ = [ "InterestMemoryCache", + "ActivityStatsCache", ] diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py new file mode 100644 index 00000000..6b162cdd --- /dev/null +++ b/api/app/cache/memory/activity_stats_cache.py @@ -0,0 +1,124 @@ +""" +Recent Activity Stats Cache + +记忆提取活动统计缓存模块 +用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 +查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + +# 缓存过期时间:24小时 +ACTIVITY_STATS_CACHE_EXPIRE = 86400 + + +class ActivityStatsCache: + """记忆提取活动统计缓存类""" + + PREFIX = "cache:memory:activity_stats" + + @classmethod + def _get_key(cls, workspace_id: str) -> str: + """生成 Redis key + + Args: + workspace_id: 工作空间ID + + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_workspace:{workspace_id}" + + @classmethod + async def set_activity_stats( + cls, + workspace_id: str, + stats: Dict[str, Any], + expire: int = ACTIVITY_STATS_CACHE_EXPIRE, + ) -> bool: + """设置记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + stats: 统计数据,格式: + { + "chunk_count": int, + "statements_count": int, + "triplet_entities_count": int, + "triplet_relations_count": int, + "temporal_count": int, + } + expire: 过期时间(秒),默认24小时 + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(workspace_id) + payload = { + "stats": stats, + "generated_at": datetime.now().isoformat(), + "workspace_id": workspace_id, + "cached": True, + } + value = json.dumps(payload, 
ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_activity_stats( + cls, + workspace_id: str, + ) -> Optional[Dict[str, Any]]: + """获取记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + + Returns: + 统计数据字典,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(workspace_id) + value = await aio_redis.get(key) + if value: + payload = json.loads(value) + logger.info(f"命中活动统计缓存: {key}") + return payload + logger.info(f"活动统计缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_activity_stats( + cls, + workspace_id: str, + ) -> bool: + """删除记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(workspace_id) + result = await aio_redis.delete(key) + logger.info(f"删除活动统计缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 0319e079..e6b239dd 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -113,6 +113,7 @@ celery_app.conf.update( 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, + 'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'}, }, ) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 1b5b45fb..1c82b636 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -149,6 +149,17 @@ async def get_workspace_end_users( return {uid: {"total": 0} for uid in 
end_user_ids} + # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 + try: + from app.celery_app import celery_app as _celery_app + _celery_app.send_task( + "app.tasks.init_implicit_emotions_for_users", + kwargs={"end_user_ids": end_user_ids}, + ) + api_logger.info(f"已触发隐性记忆按需初始化任务,候选用户数: {len(end_user_ids)}") + except Exception as e: + api_logger.warning(f"触发隐性记忆按需初始化任务失败(不影响主流程): {e}") + # 并发执行配置查询和记忆数量查询 memory_configs_map, memory_nums_map = await asyncio.gather( get_memory_configs(), diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index ee45fb83..d91dfc36 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -544,10 +544,11 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info("Recent activity stats requested") +) -> dict: + workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None + api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") try: - result = await analytics_recent_activity_stats() + result = await analytics_recent_activity_stats(workspace_id=workspace_id) return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 0de3d4fe..6204a745 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -371,6 +371,11 @@ def update_model( if model_data.type is not None or model_data.provider is not None: raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER) + + if model_data.is_active: + active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, 
model_config_id=model_id, is_active=model_data.is_active) + if not active_keys: + raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER) try: api_logger.debug(f"开始更新模型配置: model_id={model_id}") diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index a3624ea4..ce5b15c0 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -97,6 +97,12 @@ async def create_tool( ): """创建工具""" try: + # 将 MCP 来源字段合并进 config + if request.tool_type == ToolType.MCP: + for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"): + val = getattr(request, key, None) + if val is not None: + request.config[key] = val tool_id = service.create_tool( name=request.name, tool_type=request.tool_type, diff --git a/api/app/core/config.py b/api/app/core/config.py index ba17da93..bbe327b6 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -192,8 +192,10 @@ class Settings: # Celery configuration (internal) # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 # 详见 docs/celery-env-bug-report.md - REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1")) - REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2")) + # 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离 + # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰 + REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) + REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) # SMTP Email Configuration SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index c8cc0460..784e5802 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -111,7 +111,7 @@ async def 
Split_The_Problem(state: ReadState) -> ReadState: "error_type": type(e).__name__, "error_message": str(e), "content_length": len(content), - "llm_model_id": memory_config.llm_model_id if memory_config else None + "llm_model_id": str(memory_config.llm_model_id) if memory_config else None } logger.error(f"Split_The_Problem error details: {error_details}") @@ -221,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: "error_type": type(e).__name__, "error_message": str(e), "questions_count": len(databasets), - "llm_model_id": memory_config.llm_model_id if memory_config else None + "llm_model_id": str(memory_config.llm_model_id) if memory_config else None } logger.error(f"Problem_Extension error details: {error_details}") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index ad0473fc..10fe96ba 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -1,3 +1,4 @@ +from app.cache.memory.interest_memory import InterestMemoryCache from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.write_tools import write from app.core.logging_config import get_agent_logger @@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState: ) logger.info(f"Write completed successfully! 
Config: {memory_config.config_name}") + # 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成 + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id=end_user_id, + language=lang, + ) + if deleted: + logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + write_result = { "status": "success", "data": structured_messages, diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 22555fff..ea44d0a5 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -82,7 +82,9 @@ async def get_chunked_dialogs( pruning_config = PruningConfig( pruning_switch=memory_config.pruning_enabled, pruning_scene=memory_config.pruning_scene or "education", - pruning_threshold=memory_config.pruning_threshold + pruning_threshold=memory_config.pruning_threshold, + scene_id=str(memory_config.scene_id) if memory_config.scene_id else None, + ontology_classes=memory_config.ontology_classes, ) logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 93c6ef6f..22030278 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -225,5 +225,24 @@ async def write( with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + + stats_to_cache = { + "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, + "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, + "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, + 
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, + "temporal_count": 0, + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file diff --git a/api/app/core/memory/models/config_models.py b/api/app/core/memory/models/config_models.py index ca1780aa..c2d62ac1 100644 --- a/api/app/core/memory/models/config_models.py +++ b/api/app/core/memory/models/config_models.py @@ -10,7 +10,7 @@ Classes: TemporalSearchParams: Parameters for temporal search queries """ -from typing import Optional +from typing import Optional, List from pydantic import BaseModel, Field @@ -55,17 +55,26 @@ class PruningConfig(BaseModel): Attributes: pruning_switch: Enable or disable semantic pruning - pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound') + pruning_scene: Scene name for pruning, either a built-in key + ('education', 'online_service', 'outbound') or a custom scene_name + from ontology_scene table pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) + scene_id: Optional ontology scene UUID, used to load custom ontology classes + ontology_classes: List of class_name strings from ontology_class table, + injected into the prompt when pruning_scene is not a built-in scene """ pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") pruning_scene: str = Field( "education", - description="Scene for pruning: one of 'education', 'online_service', 'outbound'.", + description="Scene for pruning: built-in key or custom scene_name from ontology_scene.", ) pruning_threshold: float = Field( 
0.5, ge=0.0, le=0.9, description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") + scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).") + ontology_classes: Optional[List[str]] = Field( + None, description="Class names from ontology_class table for custom scenes." + ) class TemporalSearchParams(BaseModel): diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index 0a913633..904b238f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -86,19 +86,26 @@ class SemanticPruner: self._detailed_prune_logging = True # 是否启用详细日志 self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 - # 加载场景特定配置 + # 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则) self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( self.config.pruning_scene, fallback_to_generic=True ) - # 检查场景是否有专门支持 - is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) - if is_supported: - self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置") + # 判断是否为内置专门场景 + self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) + + # 自定义场景的本体类型列表(用于注入提示词) + self._ontology_classes = getattr(self.config, "ontology_classes", None) or [] + + if self._is_builtin_scene: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置") else: - self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)") - self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}") + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入") + if self._ontology_classes: + self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}") + else: + self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词") # Load Jinja2 template self.template 
= prompt_env.get_template("extracat_Pruning.jinja2") @@ -424,12 +431,16 @@ class SemanticPruner: self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目") rendered = self.template.render( - pruning_scene=self.config.pruning_scene, + pruning_scene=self.config.pruning_scene, + is_builtin_scene=self._is_builtin_scene, + ontology_classes=self._ontology_classes, dialog_text=dialog_text, language=self.language ) log_template_rendering("extracat_Pruning.jinja2", { "pruning_scene": self.config.pruning_scene, + "is_builtin_scene": self._is_builtin_scene, + "ontology_classes_count": len(self._ontology_classes), "language": self.language }) log_prompt_rendering("pruning-extract", rendered) diff --git a/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 b/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 index 8253924b..6b620df9 100644 --- a/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 @@ -1,6 +1,6 @@ {# 对话级抽取与相关性判定模板(用于剪枝加速) - 输入:pruning_scene, dialog_text + 输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language 输出:严格 JSON(不要包含任何多余文本),字段: - is_related: bool,是否与所选场景相关 - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) @@ -16,7 +16,8 @@ - 仅输出上述键;避免多余解释或字段。 #} -{% set scene_instructions = { +{# ── 内置场景的固定说明 ── #} +{% set builtin_scene_instructions = { 'education': { 'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。', 'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.' 
@@ -31,16 +32,40 @@ } } %} -{% set scene_key = pruning_scene %} -{% if scene_key not in scene_instructions %} -{% set scene_key = 'education' %} +{# ── 确定最终使用的场景说明 ── #} +{% if is_builtin_scene %} + {# 内置专门场景:使用固定说明 #} + {% set scene_key = pruning_scene %} + {% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %} + {% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %} + {% set custom_types_str = '' %} +{% else %} + {# 自定义场景:使用场景名称 + 本体类型列表构建说明 #} + {% if ontology_classes and ontology_classes | length > 0 %} + {% if language == 'en' %} + {% set custom_types_str = ontology_classes | join(', ') %} + {% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %} + {% else %} + {% set custom_types_str = ontology_classes | join('、') %} + {% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %} + {% endif %} + {% else %} + {# 无本体类型时退化为通用说明 #} + {% if language == 'en' %} + {% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' 
%} + {% else %} + {% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %} + {% endif %} + {% set custom_types_str = '' %} + {% endif %} {% endif %} -{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %} - {% if language == "zh" %} 请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性: 场景说明:{{ instruction }} +{% if not is_builtin_scene and custom_types_str %} +重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。 +{% endif %} 对话全文: """ @@ -60,6 +85,9 @@ {% else %} Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario: Scenario Description: {{ instruction }} +{% if not is_builtin_scene and custom_types_str %} +Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true). +{% endif %} Full Dialogue: """ diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 06c988d3..467beb07 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -8,34 +8,60 @@ from typing import Any from urllib.parse import quote from app.core.workflow.adapters.base_converter import BaseConverter -from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ +from app.core.workflow.adapters.errors import ( + UnsupportVariableType, + UnknowModelWarning, + ExceptionDefineition, ExceptionType -from app.core.workflow.nodes.assigner import AssignerNodeConfig +) from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig -from app.core.workflow.nodes.code import CodeNodeConfig from app.core.workflow.nodes.code.config import InputVariable, OutputVariable -from app.core.workflow.nodes.configs import 
StartNodeConfig, LLMNodeConfig -from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig -from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ +from app.core.workflow.nodes.configs import ( + StartNodeConfig, + LLMNodeConfig, + AssignerNodeConfig, + CodeNodeConfig, + LoopNodeConfig, + IterationNodeConfig, + EndNodeConfig, + HttpRequestNodeConfig, + IfElseNodeConfig, + JinjaRenderNodeConfig, + KnowledgeRetrievalNodeConfig, + NoteNodeConfig, + ParameterExtractorNodeConfig, + QuestionClassifierNodeConfig, + VariableAggregatorNodeConfig +) +from app.core.workflow.nodes.cycle_graph.config import ( + ConditionDetail as LoopConditionDetail, + ConditionsConfig, CycleVariable -from app.core.workflow.nodes.end import EndNodeConfig -from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ - HttpContentType, HttpErrorHandle -from app.core.workflow.nodes.http_request import HttpRequestNodeConfig -from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ - HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig -from app.core.workflow.nodes.if_else import IfElseNodeConfig +) +from app.core.workflow.nodes.enums import ( + ValueInputType, + ComparisonOperator, + AssignmentOperator, + HttpAuthType, + HttpContentType, + HttpErrorHandle, + NodeType +) +from app.core.workflow.nodes.http_request.config import ( + HttpAuthConfig, + HttpContentTypeConfig, + HttpFormData, + HttpTimeOutConfig, + HttpRetryConfig, + HttpErrorDefaultTamplete, + HttpErrorHandleConfig +) from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig -from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig -from app.core.workflow.nodes.knowledge import 
KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig -from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig -from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig from app.core.workflow.nodes.question_classifier.config import ClassifierConfig -from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE @@ -48,24 +74,24 @@ class DifyConverter(BaseConverter): def __init__(self): self.CONFIG_CONVERT_MAP = { - "start": self.convert_start_node_config, - "llm": self.convert_llm_node_config, - "answer": self.convert_end_node_config, - "if-else": self.convert_if_else_node_config, - "loop": self.convert_loop_node_config, - "iteration": self.convert_iteration_node_config, - "assigner": self.convert_assigner_node_config, - "code": self.convert_code_node_config, - "http-request": self.convert_http_node_config, - "template-transform": self.convert_jinja_render_node_config, - "knowledge-retrieval": self.convert_knowledge_node_config, - "parameter-extractor": self.convert_parameter_extractor_node_config, - "question-classifier": self.convert_question_classifier_node_config, - "variable-aggregator": self.convert_variable_aggregator_node_config, - "tool": self.convert_tool_node_config, - "loop-start": lambda x: {}, - "iteration-start": lambda x: {}, - "loop-end": lambda x: {}, + NodeType.START: self.convert_start_node_config, + NodeType.LLM: self.convert_llm_node_config, + NodeType.END: self.convert_end_node_config, + NodeType.IF_ELSE: self.convert_if_else_node_config, + NodeType.LOOP: self.convert_loop_node_config, + NodeType.ITERATION: self.convert_iteration_node_config, + NodeType.ASSIGNER: self.convert_assigner_node_config, + NodeType.CODE: self.convert_code_node_config, + 
NodeType.HTTP_REQUEST: self.convert_http_node_config, + NodeType.JINJARENDER: self.convert_jinja_render_node_config, + NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config, + NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config, + NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config, + NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config, + NodeType.TOOL: self.convert_tool_node_config, + NodeType.NOTES: self.convert_notes_config, + NodeType.CYCLE_START: lambda x: {}, + NodeType.BREAK: lambda x: {}, } def get_node_convert(self, node_type): @@ -129,11 +155,11 @@ class DifyConverter(BaseConverter): @staticmethod def _convert_file(var): - pass + return None @staticmethod def _convert_array_file(var): - pass + return [] @staticmethod def variable_type_map(source_type) -> VariableType | None: @@ -185,6 +211,9 @@ class DifyConverter(BaseConverter): "not empty": ComparisonOperator.NOT_EMPTY, "start with": ComparisonOperator.START_WITH, "end with": ComparisonOperator.END_WITH, + "not contains": ComparisonOperator.NOT_CONTAINS, + "exists": ComparisonOperator.NOT_EMPTY, + "not exists": ComparisonOperator.EMPTY } return operator_map.get(operator, operator) @@ -198,7 +227,7 @@ class DifyConverter(BaseConverter): "over-write": AssignmentOperator.COVER, "remove-last": AssignmentOperator.REMOVE_LAST, "remove-first": AssignmentOperator.REMOVE_FIRST, - + "set": AssignmentOperator.ASSIGN, } return operator_map.get(operator, operator) @@ -267,10 +296,10 @@ class DifyConverter(BaseConverter): type=var_type, required=var["required"], default=self.convert_variable_type( - var_type, var["default"] + var_type, var.get("default") ), description=var["label"], - max_length=var.get("max_length"), + max_length=var.get("max_length", 50), ) start_vars.append(var_def) result = StartNodeConfig.model_construct( @@ -333,7 +362,7 @@ class DifyConverter(BaseConverter): MessageConfig( role="user", 
content=self.trans_variable_format( - node_data["memory"].get("query_prompt_template", "{{#sys.query#}}") + node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}" ) ) ) @@ -364,7 +393,7 @@ class DifyConverter(BaseConverter): node_data = node["data"] cases = [] for case in node_data["cases"]: - case_id = case["id"] + case_id = case.get("id") or case.get("case_id") logical_operator = case["logical_operator"] conditions = [] for condition in case["conditions"]: @@ -540,7 +569,8 @@ class DifyConverter(BaseConverter): ] = self.trans_variable_format(content["value"]) else: if node_data["body"]["data"]: - body_content = node_data["body"]["data"][0]["value"] + body_content = (node_data["body"]["data"][0].get("value") or + self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) else: body_content = "" @@ -612,7 +642,7 @@ class DifyConverter(BaseConverter): ), headers=headers, params=params, - verify_ssl=node_data["ssl_verify"], + verify_ssl=node_data.get("ssl_verify", False), timeouts=HttpTimeOutConfig.model_construct( connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, read_timeout=node_data["timeout"]["max_read_timeout"] or 5, @@ -696,7 +726,7 @@ class DifyConverter(BaseConverter): group_variables = {} group_type = {} if not advanced_settings or not advanced_settings["group_enabled"]: - group_variables["output"] = [ + group_variables = [ self._process_list_variable_litearl(variable) for variable in node_data["variables"] ] @@ -728,3 +758,16 @@ class DifyConverter(BaseConverter): detail=f"Please reconfigure the tool node.", )) return {} + + @staticmethod + def convert_notes_config(node: dict): + node_data = node["data"] + result = NoteNodeConfig.model_construct( + author=node_data.get("author", ""), + text=node_data.get("text", ""), + width=node_data.get("width", 80), + height=node_data.get("height", 80), + theme=node_data.get("theme", "blue"), + show_author=node_data.get("showAuthor", True) + ).model_dump() + return 
result diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index 6336b1f9..10397ad0 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -44,12 +44,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): "parameter-extractor": NodeType.PARAMETER_EXTRACTOR, "question-classifier": NodeType.QUESTION_CLASSIFIER, "variable-aggregator": NodeType.VAR_AGGREGATOR, - "tool": NodeType.TOOL + "tool": NodeType.TOOL, + "": NodeType.NOTES } def __init__(self, config: dict[str, Any]): DifyConverter.__init__(self) - BasePlatformAdapter.__init__(self, config) + BasePlatformAdapter.__init__(self, config) def get_metadata(self) -> PlatformMetadata: return PlatformMetadata( @@ -58,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): support_node_types=list(self.NODE_TYPE_MAPPING.keys()) ) - def map_node_type(self, platform_node_type) -> str: + def map_node_type(self, platform_node_type) -> NodeType: return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN) @property @@ -83,6 +84,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) if not all(field in self.config for field in require_fields): return False + if self.config.get("app", {}).get("mode") == "workflow": + self.errors.append(ExceptionDefineition( + type=ExceptionType.PLATFORM, + detail="workflow mode is not supported" + )) + return False for node in self.origin_nodes: if not self._valid_nodes(node): @@ -134,6 +141,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): for node in self.origin_nodes: if self.map_node_type(node["data"]["type"]) == NodeType.LLM: self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output" + elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL: + self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output" def 
_convert_cycle_node_position(self, node_id: str, position: dict): for node in self.origin_nodes: @@ -154,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: node_data = node["data"] try: + node_type = self.map_node_type(node_data["type"]) return NodeDefinition( id=node["id"], - type=self.map_node_type(node_data["type"]), - name=node_data.get("title"), + type=node_type, + name=node_data.get("title") or "notes", cycle=node.get("parentId"), description=None, - config=self._convert_node_config(node), + config=self._convert_node_config(node_type, node), position={ "x": node["position"]["x"], "y": node["position"]["y"] @@ -174,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): except Exception as e: logger.debug(f"convert node error - {e}", exc_info=True) - def _convert_node_config(self, node: dict): - node_data = node["data"] - node_type = node_data["type"] + def _convert_node_config(self, node_type: NodeType, node: dict): try: + node_data = node["data"] converter = self.get_node_convert(node_type) - if node_type not in self.CONFIG_CONVERT_MAP: + if node_type == NodeType.UNKNOWN: self.errors.append(ExceptionDefineition( type=ExceptionType.NODE, node_id=node["id"], node_name=node["data"]["title"], - detail=f"node type {node_type} is unsupported", + detail=f"node type {node_data.get('type')} is unsupported", )) return converter(node) except Exception as e: @@ -201,16 +210,15 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): source = edge["source"] target = edge["target"] - edge_id = edge["id"] label = None if source in self.branch_node_cache: - case_id = "-".join(edge_id.split("-")[1:-2]) + case_id = edge["sourceHandle"] if case_id == "false": - label = f'CASE{len(self.branch_node_cache[source])+1}' + label = f'CASE{len(self.branch_node_cache[source]) + 1}' else: label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' if source in 
self.error_branch_node_cache: - case_id = "-".join(edge_id.split("-")[1:-2]) + case_id = edge["sourceHandle"] if case_id == "source": label = "SUCCESS" else: @@ -235,6 +243,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): name=variable["name"], default=variable["value"], type=self.variable_type_map(variable["value_type"]), + description=variable.get("description") ) except Exception as e: self.errors.append(ExceptionDefineition( @@ -248,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: return ExecutionConfig() - - diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 7b5c059c..90668ad9 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -292,6 +292,8 @@ class GraphBuilder: """ for node in self.nodes: node_type = node.get("type") + if node_type == NodeType.NOTES: + continue node_id = node.get("id") cycle_node = node.get("cycle") if cycle_node: @@ -320,7 +322,7 @@ class GraphBuilder: # Used later to determine which branch to take based on the node's output # Assumes node output `node..output` matches the edge's label # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" if node_instance: # Wrap node's run method to avoid closure issues diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 78149e4c..ff979f2b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -158,18 +158,36 @@ class WorkflowExecutor: full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) # Append messages for user and assistant - result["messages"].extend( - 
[ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "assistant", - "content": full_content - } - ] - ) + if input_data.get("files"): + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "user", + "content": input_data.get("files") + }, + { + "role": "assistant", + "content": full_content + } + ] + ) + else: + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "assistant", + "content": full_content + } + ] + ) # Calculate elapsed time end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() @@ -308,18 +326,36 @@ class WorkflowExecutor: elapsed_time = (end_time - start_time).total_seconds() # Append messages for user and assistant - result["messages"].extend( - [ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "assistant", - "content": full_content - } - ] - ) + if input_data.get("files"): + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "user", + "content": input_data.get("files") + }, + { + "role": "assistant", + "content": full_content + } + ] + ) + else: + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "assistant", + "content": full_content + } + ] + ) logger.info( f"Workflow execution completed (streaming), " f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}" diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 973e120d..4ae89376 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel): - tags: 节点标签(用于分类和搜索) """ - name: str | None = Field( - default=None, - description="节点名称(显示名称),如果不设置则使用节点 ID" - ) - - description: 
str | None = Field( - default=None, - description="节点描述,说明节点的作用" - ) - - tags: list[str] = Field( - default_factory=list, - description="节点标签,用于分类和搜索" - ) + # name: str | None = Field( + # default=None, + # description="节点名称(显示名称),如果不设置则使用节点 ID" + # ) + # + # description: str | None = Field( + # default=None, + # description="节点描述,说明节点的作用" + # ) + # + # tags: list[str] = Field( + # default_factory=list, + # description="节点标签,用于分类和搜索" + # ) class Config: """Pydantic 配置""" diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 3f30718c..496454ba 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,7 +1,7 @@ import asyncio import logging -import uuid from abc import ABC, abstractmethod +from datetime import datetime from functools import cached_property from typing import Any, AsyncGenerator @@ -13,6 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.variable.base_variable import VariableType, FileObject from app.db import get_db_read +from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy from app.schemas import FileInput from app.services.multimodal_service import MultimodalService @@ -617,17 +618,31 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None: + async def process_message( + provider: str, + is_omni: bool, + content: str | dict | FileObject, + enable_file=False + ) -> list | str | None: + if isinstance(content, dict): + content = FileObject( + type=content.get("type"), + url=content.get("url"), + transfer_method=content.get("transfer_method"), + origin_file_type=content.get("origin_file_type"), + file_id=content.get("file_id"), + is_file=True + ) if isinstance(content, str): if enable_file: - return {"text": content} 
+ return [{"type": "text", "text": content}] return content elif isinstance(content, FileObject): if content.content_cache.get(provider): return content.content_cache[provider] with get_db_read() as db: - multimodel_service = MultimodalService(db, provider) + multimodel_service = MultimodalService(db, provider, is_omni=is_omni) message = await multimodel_service.process_files( [FileInput.model_construct( type=content.type, @@ -637,10 +652,9 @@ class BaseNode(ABC): upload_file_id=content.file_id )] ) - if message: - content.content_cache[provider] = message[0] - return message[0] + content.content_cache[provider] = message + return message return None raise TypeError(f'Unexpect input value type - {type(content)}') @@ -658,3 +672,12 @@ class BaseNode(ABC): elif isinstance(content, str): return content return result + + @staticmethod + def model_balance(model_config: ModelConfig) -> ModelApiKey: + api_keys = [key for key in model_config.api_keys if key.is_active] + if not api_keys: + raise ValueError("No active API keys available for model") + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + return api_keys[0] diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index e4e418fe..31dadc38 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.nodes.notes.config import NoteNodeConfig __all__ = [ # 基础类 @@ -47,5 +48,6 @@ __all__ = [ "ToolNodeConfig", "MemoryReadNodeConfig", "MemoryWriteNodeConfig", - "CodeNodeConfig" + 
"CodeNodeConfig", + "NoteNodeConfig" ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index ae9b81ff..43ab593b 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -25,6 +25,7 @@ class NodeType(StrEnum): MEMORY_WRITE = "memory-write" UNKNOWN = "unknown" + NOTES = "notes" BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 17f55319..696298eb 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode): RuntimeError: If no valid knowledge base is found or access is denied. """ self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + if not self.typed_config.knowledge_bases: + return [] query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index c109d59b..186c204f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -112,11 +112,12 @@ class LLMNode(BaseNode): raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) # 在 Session 关闭前提取所有需要的数据 - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type # 4. 
创建 LLM 实例(使用已提取的数据) @@ -129,7 +130,8 @@ class LLMNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, - extra_params=extra_params + extra_params=extra_params, + is_omni=is_omni ), type=ModelType(model_type) ) @@ -151,39 +153,53 @@ class LLMNode(BaseNode): if role == "system": messages.append({ "role": "system", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] files = variable_pool.get_instance(self.typed_config.vision_input) for file in files.value: - content = await self.process_message(provider, file.value, self.typed_config.vision) + content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) if content: - file_content.append(content) + file_content.extend(content) if messages and messages[-1]["role"] == 'user': - messages[-1]['content'] = [messages[-1]["content"]] + file_content + messages[-1]['content'] = messages[-1]["content"] + file_content else: messages.append({"role": "user", "content": file_content}) if self.typed_config.memory.enable: - messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] + history_message = [] + for message in state["messages"][-self.typed_config.memory.window_size:]: + if isinstance(message["content"], list): + 
file_content = [] + for file in message["content"]: + content = await self.process_message(provider, is_omni, file, self.typed_config.vision) + if content: + file_content.extend(content) + history_message.append( + {"role": message["role"], "content": file_content} + ) + else: + message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) + history_message.append(message) + messages = messages[:-1] + history_message + messages[-1:] self.messages = messages else: # 使用简单的 prompt 格式(向后兼容) diff --git a/api/app/core/workflow/nodes/notes/__init__.py b/api/app/core/workflow/nodes/notes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/workflow/nodes/notes/config.py b/api/app/core/workflow/nodes/notes/config.py new file mode 100644 index 00000000..42b4a1ab --- /dev/null +++ b/api/app/core/workflow/nodes/notes/config.py @@ -0,0 +1,12 @@ +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class NoteNodeConfig(BaseNodeConfig): + author: str = Field(default="", description="author") + text: str = Field(default="", description="note content") + width: int = Field(default=80) + height: int = Field(default=80) + theme: str = Field(default="blue") + show_author: bool = Field(default=True) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 4811c118..700ed85f 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = 
api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index e2fd97ae..5cebd886 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base + is_omni = api_config.is_omni model_type = config.type return RedBearLLM( @@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode): provider=provider, api_key=api_key, base_url=base_url, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 47256b75..3b6e9036 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -138,7 +138,7 @@ class WorkflowValidator: errors.append("工作流必须至少有一个 end 节点") # 3. 
验证节点 ID 唯一性 - node_ids = [n.get("id") for n in nodes] + node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES] if len(node_ids) != len(set(node_ids)): duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py index ccd28693..98448bc5 100644 --- a/api/app/models/tool_model.py +++ b/api/app/models/tool_model.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from enum import StrEnum -from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean +from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -163,6 +163,17 @@ class CustomToolConfig(Base): return f"" +class MCPSourceChannel(StrEnum): + """MCP来源渠道枚举""" + ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼 + MODELSCOPE = "modelscope" # ModelScope + TOKENFLUX = "tokenflux" # TokenFlux + LANGENG = "langeng" # 蓝耕科技 + AI_302 = "302ai" # 302.AI + MCP_ROUTER = "mcp_router" # MCP Router + SELF_HOSTED = "self_hosted" # 自建 + + class MCPToolConfig(Base): """MCP工具配置模型""" __tablename__ = "mcp_tool_configs" @@ -170,6 +181,13 @@ class MCPToolConfig(Base): id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) server_url = Column(String(1000), nullable=False) # MCP服务器URL connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息) + + # 来源渠道 + source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED, + server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道") + market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id") + market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id") + mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id") # 服务状态 last_health_check = Column(DateTime) 
diff --git a/api/app/repositories/implicit_emotions_storage_repository.py b/api/app/repositories/implicit_emotions_storage_repository.py index 97405ab6..58e98dfd 100644 --- a/api/app/repositories/implicit_emotions_storage_repository.py +++ b/api/app/repositories/implicit_emotions_storage_repository.py @@ -5,13 +5,15 @@ Implicit Emotions Storage Repository 事务由调用方控制,仓储层只使用 flush/refresh """ import logging -from datetime import datetime, date, timezone, timedelta -from typing import Optional, Generator -from sqlalchemy.orm import Session -from sqlalchemy import select, not_, exists +from datetime import date, datetime, timedelta, timezone +from typing import Generator, Optional + +import redis +from sqlalchemy import exists, not_, select +from sqlalchemy.orm import Session -from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage from app.models.end_user_model import EndUser +from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage logger = logging.getLogger(__name__) @@ -111,6 +113,96 @@ class ImplicitEmotionsStorageRepository: logger.error(f"分批获取用户ID失败: offset={offset}, error={e}") break + def get_users_needing_refresh(self, redis_client: Optional[redis.StrictRedis], batch_size: int = 100) -> Generator[str, None, None]: + """分批次获取需要刷新隐性记忆/情绪数据的存量用户ID。 + + 筛选逻辑: + - 查询 implicit_emotions_storage 中所有用户的 end_user_id 和 updated_at + - 从 Redis 读取 write_message:last_done:{end_user_id} 的时间戳 + - 若 Redis 中无记录(该用户从未写入过记忆),跳过 + - 若 last_done > updated_at,说明上次刷新后又有新记忆写入,需要刷新 + - 若 last_done <= updated_at,说明已是最新,跳过 + + 如果 redis_client 为 None,则降级为返回所有用户(禁用时间过滤)。 + + Args: + redis_client: 同步 redis.StrictRedis 实例(连接 CELERY_BACKEND DB),如果为 None 则禁用时间过滤 + batch_size: 每批次加载的数量 + + Yields: + 需要刷新的用户ID字符串 + """ + from datetime import timezone + + from redis.exceptions import RedisError + + # 如果 Redis 不可用,降级为处理所有用户 + if redis_client is None: + logger.warning( + "Redis 客户端不可用,时间过滤已禁用,将处理所有存量用户" + ) + yield from self.get_all_user_ids(batch_size) + 
return + + offset = 0 + while True: + try: + stmt = ( + select(ImplicitEmotionsStorage.end_user_id, ImplicitEmotionsStorage.updated_at) + .order_by(ImplicitEmotionsStorage.end_user_id) + .limit(batch_size) + .offset(offset) + ) + batch = self.db.execute(stmt).all() + if not batch: + break + + # 批量获取当前批次所有用户的 last_done 时间戳(一次网络往返) + keys = [f"write_message:last_done:{end_user_id}" for end_user_id, _ in batch] + + try: + raw_values = redis_client.mget(keys) + except RedisError as e: + logger.error( + f"Redis mget 操作失败: {e},当前批次降级为处理所有用户", + extra={"offset": offset, "batch_size": len(batch)} + ) + # Redis 操作失败,降级为返回当前批次所有用户 + yield from (end_user_id for end_user_id, _ in batch) + offset += batch_size + continue + + for (end_user_id, updated_at), raw in zip(batch, raw_values): + if raw is None: + continue + try: + CST = timezone(timedelta(hours=8)) + last_done = datetime.fromisoformat(raw) + # 统一转为 CST naive 时间做比较 + if last_done.tzinfo is None: + last_done = last_done.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None) + else: + last_done = last_done.astimezone(CST).replace(tzinfo=None) + + if updated_at is None: + yield end_user_id + continue + # updated_at 同样转为 CST naive + if updated_at.tzinfo is None: + updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None) + else: + updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None) + + if last_done > updated_at_cst: + yield end_user_id + except Exception as e: + logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}") + + offset += batch_size + except Exception as e: + logger.error(f"get_users_needing_refresh 分批查询失败: offset={offset}, error={e}") + break + def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]: """分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID @@ -124,7 +216,8 @@ class ImplicitEmotionsStorageRepository: Yields: 用户ID字符串 """ - from sqlalchemy import cast, String as SAString + from sqlalchemy import String 
as SAString + from sqlalchemy import cast CST = timezone(timedelta(hours=8)) now_cst = datetime.now(CST) today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None) diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 2dae51ef..22f13449 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -233,6 +233,7 @@ class MemoryConfigRepository: config_desc=params.config_desc, workspace_id=params.workspace_id, scene_id=params.scene_id, + pruning_scene=params.pruning_scene, llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 0fcbc718..13766ef6 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -86,6 +86,7 @@ class ChatResponse(BaseModel): """聊天响应(非流式)""" conversation_id: uuid.UUID message: str + message_id: str usage: Optional[Dict[str, Any]] = None elapsed_time: Optional[float] = None diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 0b63844b..0c359d70 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -417,6 +417,7 @@ class MemoryConfig: # Ontology scene association scene_id: Optional[UUID] = None + ontology_classes: Optional[list] = field(default=None) def __post_init__(self): """Validate configuration after initialization.""" diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 776d2783..046b79e7 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -232,14 +232,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, 
description="本体场景ID(UUID),关联ontology_scene表") + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) + pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") - - class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") @@ -274,8 +275,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 # 剪枝配置:与 runtime.json 中 pruning 段对应 pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝") - pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field( - None, description="智能剪枝场景:education/online_service/outbound" + pruning_scene: Optional[str] = Field( + None, description="智能剪枝场景:education/online_service/outbound 或本体工程自定义场景" ) pruning_threshold: Optional[float] = Field( None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)" diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index ea4183a5..4f3878ce 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -23,6 +23,7 @@ class ModelConfigBase(BaseModel): load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略") capability: List[str] = Field(default_factory=list, description="模型能力列表") is_omni: bool = Field(False, description="是否为Omni模型") + model_id: Optional[uuid.UUID] = Field(None, description="基础模型ID") class ApiKeyCreateNested(BaseModel): diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py index 48afe2c3..2ba86c2c 100644 
--- a/api/app/schemas/tool_schema.py +++ b/api/app/schemas/tool_schema.py @@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel): health_status: str = "unknown" error_message: Optional[str] = None available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]") + source_channel: Optional[str] = Field(None, description="来源渠道") + market_id: Optional[str] = Field(None, description="渠道市场id") + market_config_id: Optional[str] = Field(None, description="渠道市场配置id") + mcp_service_id: Optional[str] = Field(None, description="mcp服务id") class Config: from_attributes = True @@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel): tool_type: ToolType config: Dict[str, Any] = Field(default_factory=dict) tags: List[str] = Field(default_factory=list) + source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)") + market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)") + market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)") + mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)") class ToolUpdateRequest(BaseModel): diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 5430d2f9..f3cdde2a 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -144,7 +144,7 @@ class AppChatService: ) # 保存消息 - self.conversation_service.save_conversation_messages( + message_id = self.conversation_service.save_conversation_messages( conversation_id=conversation_id, user_message=message, assistant_message=result["content"], @@ -163,6 +163,7 @@ class AppChatService: return { "conversation_id": conversation_id, + "message_id": str(message_id), "message": result["content"], "usage": result.get("usage", { "prompt_tokens": 0, @@ -191,7 +192,11 @@ class AppChatService: try: start_time = time.time() config_id = None - yield f"event: start\ndata: {json.dumps({'conversation_id': 
str(conversation_id)}, ensure_ascii=False)}\n\n" + message_id = uuid.uuid4() + yield f"event: start\ndata: {json.dumps({ + 'conversation_id': str(conversation_id), + "message_id": str(message_id) + }, ensure_ascii=False)}\n\n" variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID @@ -296,6 +301,7 @@ class AppChatService: ) self.conversation_service.add_message( + message_id=message_id, conversation_id=conversation_id, role="assistant", content=full_content, @@ -373,7 +379,7 @@ class AppChatService: content=message ) - self.conversation_service.add_message( + ai_message = self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=result.get("message", ""), @@ -391,6 +397,7 @@ class AppChatService: return { "conversation_id": conversation_id, "message": result.get("message", ""), + "message_id": str(ai_message.id), "usage": { "prompt_tokens": 0, "completion_tokens": 0, @@ -419,9 +426,9 @@ class AppChatService: variables = {} try: - + message_id = uuid.uuid4() # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n" full_content = "" total_tokens = 0 @@ -429,6 +436,7 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) + # 3. 
流式执行任务 async for event in orchestrator.execute_stream( message=message, @@ -472,6 +480,7 @@ class AppChatService: ) self.conversation_service.add_message( + message_id=message_id, conversation_id=conversation_id, role="assistant", content=full_content, diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index a248f869..5a799937 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -703,7 +703,7 @@ class AppService: self.db.flush() # 如果是 agent 类型,复制 AgentConfig - if source_app.type == "agent": + if source_app.type == AppType.AGENT: source_config = self.db.query(AgentConfig).filter( AgentConfig.app_id == source_app.id ).first() @@ -725,6 +725,50 @@ class AppService: ) self.db.add(new_config) + elif source_app.type == AppType.WORKFLOW: + source_config = self.db.query(WorkflowConfig).filter( + WorkflowConfig.app_id == source_app.id + ).first() + + if source_config: + new_config = WorkflowConfig( + id=uuid.uuid4(), + app_id=new_app.id, + nodes=source_config.nodes.copy() if source_config.nodes else [], + edges=source_config.edges.copy() if source_config.edges else [], + variables=source_config.variables.copy() if source_config.variables else [], + execution_config=source_config.execution_config.copy() if source_config.execution_config else {}, + triggers=source_config.triggers.copy() if source_config.triggers else [], + is_active=True, + created_at=now, + updated_at=now, + ) + self.db.add(new_config) + + elif source_app.type == AppType.MULTI_AGENT: + source_config = self.db.query(MultiAgentConfig).filter( + MultiAgentConfig.app_id == source_app.id + ).first() + + if source_config: + new_config = MultiAgentConfig( + id=uuid.uuid4(), + app_id=new_app.id, + master_agent_id=source_config.master_agent_id, + master_agent_name=source_config.master_agent_name, + default_model_config_id=source_config.default_model_config_id, + model_parameters=source_config.model_parameters, + 
orchestration_mode=source_config.orchestration_mode, + sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [], + routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None, + execution_config=source_config.execution_config.copy() if source_config.execution_config else {}, + aggregation_strategy=source_config.aggregation_strategy, + is_active=True, + created_at=now, + updated_at=now, + ) + self.db.add(new_config) + self.db.commit() self.db.refresh(new_app) diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 553aefc4..aff5f533 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -178,7 +178,8 @@ class ConversationService: conversation_id: uuid.UUID, role: str, content: str, - meta_data: Optional[dict] = None + meta_data: Optional[dict] = None, + message_id: Optional[uuid.UUID] = None, ) -> Message: """ Add a message to a conversation using UnitOfWork. @@ -188,6 +189,7 @@ class ConversationService: role (str): Role of the message sender ('user' or 'assistant'). content (str): Message content. meta_data (Optional[dict]): Optional metadata. + message_id (Optional[uuid.UUID]): Optional custom message UUID. Returns: Message: Newly created Message instance. 
@@ -198,6 +200,7 @@ class ConversationService: ) message = Message( + id=message_id if message_id else uuid.uuid4(), conversation_id=conversation_id, role=role, content=content, @@ -317,7 +320,7 @@ class ConversationService: content=user_message ) - self.add_message( + ai_message = self.add_message( conversation_id=conversation_id, role="assistant", content=assistant_message, @@ -332,6 +335,7 @@ class ConversationService: "assistant_message_length": len(assistant_message) } ) + return ai_message.id def delete_conversation( self, diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index ccfd5482..00757f8c 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -107,6 +107,40 @@ def _validate_config_id(config_id, db: Session = None): ) +# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护 +# 使用懒加载函数避免模块级循环导入 +def _get_builtin_pruning_scenes() -> set: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry + return set(SceneConfigRegistry.get_all_scenes()) + + +def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]: + """当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。 + + Args: + db: 数据库会话 + scene_id: 本体场景 UUID + pruning_scene: 语义剪枝场景名称 + + Returns: + class_name 字符串列表,或 None(内置场景 / 无数据时) + """ + if not scene_id: + return None + # 内置场景走 SceneConfigRegistry,不需要注入类型列表 + if pruning_scene in _get_builtin_pruning_scenes(): + return None + try: + from app.repositories.ontology_class_repository import OntologyClassRepository + repo = OntologyClassRepository(db) + classes = repo.get_classes_by_scene(scene_id) + names = [c.class_name for c in classes if c.class_name] + return names if names else None + except Exception as e: + logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}") + return None + + class MemoryConfigService: """ Centralized service for 
memory configuration loading and validation. @@ -359,6 +393,7 @@ class MemoryConfigService: pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, + ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene), ) elapsed_ms = (time.time() - start_time) * 1000 diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 02fd1051..6e7c1ad4 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -146,6 +146,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.emotion_model_id: params.emotion_model_id = params.llm_id + # 根据关联的本体场景推导 pruning_scene(语义剪枝场景与本体工程场景保持一致) + if params.scene_id and not getattr(params, 'pruning_scene', None): + params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id) + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -161,6 +165,23 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) finally: db_session.close() + def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]: + """根据本体场景ID获取对应的 scene_name,作为语义剪枝场景值 + + Args: + scene_id: 本体场景UUID + + Returns: + scene_name 字符串,查询失败时返回 None + """ + try: + from app.models.ontology_scene import OntologyScene + scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first() + return scene.scene_name if scene else None + except Exception as e: + logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True) + return None + # --- Delete --- def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) @@ -196,6 +217,19 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) def get_all(self, 
workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) + # 检查并修正 pruning_scene 与 scene_name 不一致的记录 + needs_commit = False + for config, scene_name in results: + if scene_name and config.pruning_scene != scene_name: + logger.info( + f"修正 pruning_scene: config_id={config.config_id} " + f"'{config.pruning_scene}' -> '{scene_name}'" + ) + config.pruning_scene = scene_name + needs_commit = True + if needs_commit: + self.db.commit() + # 将 ORM 对象转换为字典列表 data_list = [] for config, scene_name in results: @@ -749,8 +783,37 @@ async def analytics_hot_memory_tags( await connector.close() -async def analytics_recent_activity_stats() -> Dict[str, Any]: - stats, _msg = get_recent_activity_stats() +async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]: + """获取最近记忆提取活动统计。 + + 优先从 Redis 缓存读取(按 workspace_id),缓存不存在时降级到日志文件解析。 + + Args: + workspace_id: 工作空间ID,用于从 Redis 读取对应缓存 + + Returns: + 包含 total、stats、latest_relative、source 的统计字典 + """ + stats = None + source = "log" + + # 优先从 Redis 读取 + if workspace_id: + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + cached = await ActivityStatsCache.get_activity_stats(workspace_id) + if cached: + stats = cached.get("stats", {}) + source = "redis" + logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}") + except Exception as e: + logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}") + + # 降级:从日志文件解析 + if stats is None: + stats, _msg = get_recent_activity_stats() + source = "log" + total = ( stats.get("chunk_count", 0) + stats.get("statements_count", 0) @@ -758,26 +821,29 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: + stats.get("triplet_relations_count", 0) + stats.get("temporal_count", 0) ) - # 精简:仅提供“最新一次活动多久前” - latest_relative = None - try: - info = stats.get("log_path", "") - idx = info.rfind("最新:") - if idx != -1: - latest_path = info[idx + 3 :].strip() - if 
latest_path and os.path.exists(latest_path): - import time - diff = max(0.0, time.time() - os.path.getmtime(latest_path)) - m = int(diff // 60) - if m < 1: - latest_relative = "刚刚" - elif m < 60: - latest_relative = "一会前" - else: - latest_relative = "较早前" - except Exception: - pass - data = {"total": total, "stats": stats, "latest_relative": latest_relative} + # 计算"最新一次活动多久前"(仅日志来源时有效) + latest_relative = None + if source == "log": + try: + info = stats.get("log_path", "") + idx = info.rfind("最新:") + if idx != -1: + latest_path = info[idx + 3:].strip() + if latest_path and os.path.exists(latest_path): + import time + diff = max(0.0, time.time() - os.path.getmtime(latest_path)) + m = int(diff // 60) + if m < 1: + latest_relative = "刚刚" + elif m < 60: + latest_relative = "一会前" + else: + latest_relative = "较早前" + except Exception: + pass + + data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data + diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index cba25f32..a7398504 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -780,6 +780,7 @@ class ModelBaseService: "description": model_base.description, "capability": model_base.capability, "is_omni": model_base.is_omni, + "is_active": False, "is_composite": False } model_config = ModelConfigRepository.create(db, model_config_data) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 4d9cbb5e..5d00d8a5 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -326,6 +326,25 @@ async def run_pilot_extraction( logger.info("Pilot run completed: Skipping Neo4j save") + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + + stats_to_cache = { + "chunk_count": len(chunk_nodes) if chunk_nodes else 0, + "statements_count": len(statement_nodes) if statement_nodes else 
0, + "triplet_entities_count": len(entity_nodes) if entity_nodes else 0, + "triplet_relations_count": len(entity_edges) if entity_edges else 0, + "temporal_count": 0, # temporal 数据在日志中,此处暂置0 + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + except Exception as e: logger.error(f"Pilot run failed: {e}", exc_info=True) raise diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index f6e2ccce..4fe1e9e6 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -85,7 +85,7 @@ class ToolService: """检查工具名称是否重复""" query = self.db.query(ToolConfig).filter( ToolConfig.name == name, - ToolConfig.tool_type == tool_type.value, + ToolConfig.tool_type == tool_type, ToolConfig.tenant_id == tenant_id ) if exclude_id: @@ -910,7 +910,11 @@ class ToolService: config_data.update({ "last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None, "health_status": mcp_config.health_status, - "available_tools": available_tools_display + "available_tools": available_tools_display, + "source_channel": mcp_config.source_channel, + "market_id": mcp_config.market_id, + "market_config_id": mcp_config.market_config_id, + "mcp_service_id": mcp_config.mcp_service_id }) return ToolInfo( @@ -965,7 +969,11 @@ class ToolService: id=tool_config.id, server_url=config.get("server_url"), connection_config=config.get("connection_config", {}), - available_tools=config.get("available_tools", []) + available_tools=config.get("available_tools", []), + source_channel=config.get("source_channel", "self_hosted"), + market_id=config.get("market_id"), + market_config_id=config.get("market_config_id"), + 
mcp_service_id=config.get("mcp_service_id"), ) self.db.add(mcp_config) @@ -1018,6 +1026,14 @@ class ToolService: mcp_config.server_url = config.get("server_url") mcp_config.connection_config = config.get("connection_config", {}) mcp_config.available_tools = config.get("available_tools", []) + if config.get("source_channel") is not None: + mcp_config.source_channel = config.get("source_channel") + if config.get("market_id") is not None: + mcp_config.market_id = config.get("market_id") + if config.get("market_config_id") is not None: + mcp_config.market_config_id = config.get("market_config_id") + if config.get("mcp_service_id") is not None: + mcp_config.mcp_service_id = config.get("mcp_service_id") @staticmethod def _determine_initial_status(tool_info: Dict[str, Any]) -> str: diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py index 2e17f404..2b36c5ea 100644 --- a/api/app/services/workflow_import_service.py +++ b/api/app/services/workflow_import_service.py @@ -56,7 +56,7 @@ class WorkflowImportService: success=False, temp_id=None, workflow_id=None, - errors=[InvalidConfiguration()] + errors=[InvalidConfiguration()] + adapter.errors ) workflow_config = adapter.parse_workflow() diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d13e3454..eaf78b90 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -25,7 +25,7 @@ from app.repositories.workflow_repository import ( WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest, FileInput +from app.schemas import DraftRunRequest, FileInput, FileType from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService @@ -496,6 +496,7 @@ class WorkflowService: "event": "start", "data": { "conversation_id": 
payload.get("conversation_id"), + "message_id": payload.get("message_id") } } case "workflow_end": @@ -600,6 +601,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) input_data["files"] = files + message_id = uuid.uuid4() # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") @@ -624,24 +626,45 @@ class WorkflowService: workspace_id=str(workspace_id), user_id=payload.user_id ) - # 更新执行结果 if result.get("status") == "completed": token_usage = result.get("token_usage", {}) or {} + + final_messages = result.get("messages", [])[init_message_length:] + human_message = "" + assistant_message = "" + for message in final_messages: + if message["role"] == "user": + if isinstance(message["content"], str): + human_message += message["content"] + elif isinstance(message["content"], list): + for file in message["content"]: + if file.get("type") == FileType.IMAGE: + human_message += f"![image]({file.get('url', '')})" + else: + human_message += f"[{file.get('type')}]({file.get('url', '')})" + if message["role"] == "assistant": + assistant_message = message["content"] + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="user", + content=human_message, + meta_data=None + ) + self.conversation_service.add_message( + message_id=message_id, + conversation_id=conversation_id_uuid, + role="assistant", + content=assistant_message, + meta_data={"usage": token_usage} + ) self.update_execution_status( execution.execution_id, "completed", output_data=result, token_usage=token_usage.get("total_tokens", None) ) - final_messages = result.get("messages", [])[init_message_length:] - for message in final_messages: - self.conversation_service.add_message( - conversation_id=conversation_id_uuid, - role=message["role"], - content=message["content"], - meta_data=None if message["role"] == "user" else {"usage": token_usage} - ) + logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, 
message count: {len(final_messages)}") else: @@ -650,6 +673,8 @@ class WorkflowService: "failed", error_message=result.get("error") ) + logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id}," + f" error: {result.get('error')}") # 返回增强的响应结构 return { @@ -659,6 +684,7 @@ class WorkflowService: # "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) "message": result.get("output"), # 最终输出(字符串) + "message_id": str(message_id), # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), @@ -756,7 +782,7 @@ class WorkflowService: input_data["conv_messages"] = last_state.get("messages") or [] break init_message_length = len(input_data.get("conv_messages", [])) - + message_id = uuid.uuid4() async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, @@ -765,24 +791,43 @@ class WorkflowService: user_id=payload.user_id, ): if event.get("event") == "workflow_end": - status = event.get("data", {}).get("status") token_usage = event.get("data", {}).get("token_usage", {}) or {} if status == "completed": + final_messages = event.get("data", {}).get("messages", [])[init_message_length:] + human_message = "" + assistant_message = "" + for message in final_messages: + if message["role"] == "user": + if isinstance(message["content"], str): + human_message += message["content"] + elif isinstance(message["content"], list): + for file in message["content"]: + if file.get("type") == FileType.IMAGE: + human_message += f"![image]({file.get('url', '')})" + else: + human_message += f"[{file.get('type')}]({file.get('url', '')})" + if message["role"] == "assistant": + assistant_message = message["content"] + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="user", + content=human_message, + meta_data=None + ) + 
self.conversation_service.add_message( + message_id=message_id, + conversation_id=conversation_id_uuid, + role="assistant", + content=assistant_message, + meta_data={"usage": token_usage} + ) self.update_execution_status( execution.execution_id, "completed", output_data=event.get("data"), token_usage=token_usage.get("total_tokens", None) ) - final_messages = event.get("data", {}).get("messages", [])[init_message_length:] - for message in final_messages: - self.conversation_service.add_message( - conversation_id=conversation_id_uuid, - role=message["role"], - content=message["content"], - meta_data=None if message["role"] == "user" else {"usage": token_usage} - ) logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": @@ -793,6 +838,8 @@ class WorkflowService: ) else: logger.error(f"unexpect workflow run status, status: {status}") + elif event.get("event") == "workflow_start": + event["data"]["message_id"] = str(message_id) event = self._emit(public, event) if event: yield event diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index e93c0c5c..7861ef62 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -130,6 +130,7 @@ def _create_workspace_only( business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}") raise + def create_workspace( db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh" ) -> Workspace: @@ -152,6 +153,7 @@ def create_workspace( # Initialize default ontology scenes for the workspace (先创建本体场景) default_scene_id = None + default_scene_name = None try: initializer = DefaultOntologyInitializer(db) success, error_msg = initializer.initialize_default_scenes( @@ -163,7 +165,7 @@ def create_workspace( f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})" ) - # 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景 + # 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景 from 
app.repositories.ontology_scene_repository import OntologySceneRepository from app.config.default_ontology_config import ( ONLINE_EDUCATION_SCENE, @@ -179,6 +181,7 @@ def create_workspace( if education_scene: default_scene_id = education_scene.scene_id + default_scene_name = education_scene.scene_name business_logger.info( f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})" ) @@ -189,6 +192,7 @@ def create_workspace( if companion_scene: default_scene_id = companion_scene.scene_id + default_scene_name = companion_scene.scene_name business_logger.info( f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})" ) @@ -219,6 +223,7 @@ def create_workspace( embedding_id=embedding, rerank_id=rerank, scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景) + pruning_scene_name=default_scene_name, # 传入场景名称作为语义剪枝场景值 ) business_logger.info( f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})" @@ -962,6 +967,125 @@ def update_workspace_models_configs( raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR) +def _fill_workspace_configs_model_defaults( + db: Session, + workspace: Workspace +) -> None: + """Fill empty model fields for all memory configs in a workspace. + + Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id + if they are None, using the corresponding workspace default models. 
+ + Args: + db: Database session + workspace: The workspace containing default model settings + """ + from app.models.memory_config_model import MemoryConfig + + # Get all configs for this workspace + configs = db.query(MemoryConfig).filter( + MemoryConfig.workspace_id == workspace.id + ).all() + + if not configs: + return + + # Map of memory_config field -> workspace field + model_field_mappings = [ + ("llm_id", "llm"), + ("embedding_id", "embedding"), + ("rerank_id", "rerank"), + ("reflection_model_id", "llm"), # reflection uses LLM + ("emotion_model_id", "llm"), # emotion uses LLM + ] + + configs_updated = 0 + + for memory_config in configs: + updated_fields = [] + + for config_field, workspace_field in model_field_mappings: + config_value = getattr(memory_config, config_field, None) + workspace_value = getattr(workspace, workspace_field, None) + + if not config_value and workspace_value: + setattr(memory_config, config_field, workspace_value) + updated_fields.append(config_field) + + if updated_fields: + configs_updated += 1 + business_logger.debug( + f"Updated memory config {memory_config.config_id} fields: {updated_fields}" + ) + + if configs_updated > 0: + try: + db.commit() + business_logger.info( + f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models" + ) + except Exception as e: + db.rollback() + business_logger.error( + f"Failed to update memory configs in workspace {workspace.id}: {str(e)}" + ) + + +def _create_default_memory_config( + db: Session, + workspace_id: uuid.UUID, + workspace_name: str, + llm_id: Optional[uuid.UUID] = None, + embedding_id: Optional[uuid.UUID] = None, + rerank_id: Optional[uuid.UUID] = None, + scene_id: Optional[uuid.UUID] = None, + pruning_scene_name: Optional[str] = None, +) -> None: + """Create a default memory config for a newly created workspace. 
+ + Args: + db: Database session + workspace_id: The workspace ID + workspace_name: The workspace name (used for config naming) + llm_id: Optional LLM model ID + embedding_id: Optional embedding model ID + rerank_id: Optional rerank model ID + scene_id: Optional ontology scene ID (默认关联教育场景) + pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name + """ + from app.models.memory_config_model import MemoryConfig + + config_id = uuid.uuid4() + + default_config = MemoryConfig( + config_id=config_id, + config_name=f"{workspace_name} 默认配置", + config_desc="工作空间创建时自动生成的默认记忆配置", + workspace_id=workspace_id, + llm_id=str(llm_id) if llm_id else None, + embedding_id=str(embedding_id) if embedding_id else None, + rerank_id=str(rerank_id) if rerank_id else None, + scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景) + pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name + state=True, # Active by default + is_default=True, # Mark as workspace default + ) + + db.add(default_config) + db.flush() # 使用 flush 而不是 commit,让调用者统一提交 + + business_logger.info( + "Created default memory config for workspace", + extra={ + "workspace_id": str(workspace_id), + "config_id": str(config_id), + "config_name": default_config.config_name, + "scene_id": str(scene_id) if scene_id else None, + } + ) + +# ==================== 检查配置相关服务 ==================== + def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: """Ensure a workspace has a default memory config, creating one if missing. @@ -1041,70 +1165,6 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: _fill_workspace_configs_model_defaults(db, workspace) -def _fill_workspace_configs_model_defaults( - db: Session, - workspace: Workspace -) -> None: - """Fill empty model fields for all memory configs in a workspace. - - Updates llm_id, embedding_id, rerank_id, reflection_model_id, and emotion_model_id - if they are None, using the corresponding workspace default models. 
- - Args: - db: Database session - workspace: The workspace containing default model settings - """ - from app.models.memory_config_model import MemoryConfig - - # Get all configs for this workspace - configs = db.query(MemoryConfig).filter( - MemoryConfig.workspace_id == workspace.id - ).all() - - if not configs: - return - - # Map of memory_config field -> workspace field - model_field_mappings = [ - ("llm_id", "llm"), - ("embedding_id", "embedding"), - ("rerank_id", "rerank"), - ("reflection_model_id", "llm"), # reflection uses LLM - ("emotion_model_id", "llm"), # emotion uses LLM - ] - - configs_updated = 0 - - for memory_config in configs: - updated_fields = [] - - for config_field, workspace_field in model_field_mappings: - config_value = getattr(memory_config, config_field, None) - workspace_value = getattr(workspace, workspace_field, None) - - if not config_value and workspace_value: - setattr(memory_config, config_field, workspace_value) - updated_fields.append(config_field) - - if updated_fields: - configs_updated += 1 - business_logger.debug( - f"Updated memory config {memory_config.config_id} fields: {updated_fields}" - ) - - if configs_updated > 0: - try: - db.commit() - business_logger.info( - f"Updated {configs_updated} memory configs in workspace {workspace.id} with default models" - ) - except Exception as e: - db.rollback() - business_logger.error( - f"Failed to update memory configs in workspace {workspace.id}: {str(e)}" - ) - - def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None: """Ensure a workspace has default ontology scenes, creating them if missing. 
@@ -1150,53 +1210,3 @@ def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None: f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}" ) - -def _create_default_memory_config( - db: Session, - workspace_id: uuid.UUID, - workspace_name: str, - llm_id: Optional[uuid.UUID] = None, - embedding_id: Optional[uuid.UUID] = None, - rerank_id: Optional[uuid.UUID] = None, - scene_id: Optional[uuid.UUID] = None, -) -> None: - """Create a default memory config for a newly created workspace. - - Args: - db: Database session - workspace_id: The workspace ID - workspace_name: The workspace name (used for config naming) - llm_id: Optional LLM model ID - embedding_id: Optional embedding model ID - rerank_id: Optional rerank model ID - scene_id: Optional ontology scene ID (默认关联教育场景) - """ - from app.models.memory_config_model import MemoryConfig - - config_id = uuid.uuid4() - - default_config = MemoryConfig( - config_id=config_id, - config_name=f"{workspace_name} 默认配置", - config_desc="工作空间创建时自动生成的默认记忆配置", - workspace_id=workspace_id, - llm_id=str(llm_id) if llm_id else None, - embedding_id=str(embedding_id) if embedding_id else None, - rerank_id=str(rerank_id) if rerank_id else None, - scene_id=scene_id, # 关联本体场景ID - state=True, # Active by default - is_default=True, # Mark as workspace default - ) - - db.add(default_config) - db.flush() # 使用 flush 而不是 commit,让调用者统一提交 - - business_logger.info( - "Created default memory config for workspace", - extra={ - "workspace_id": str(workspace_id), - "config_id": str(config_id), - "config_name": default_config.config_name, - "scene_id": str(scene_id) if scene_id else None, - } - ) diff --git a/api/app/tasks.py b/api/app/tasks.py index a6ebbb8e..5958d77d 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import os import re import shutil @@ -14,6 +15,62 @@ from uuid import UUID import redis import requests +from redis.exceptions import RedisError + +logger = 
logging.getLogger(__name__) + +# 模块级同步 Redis 连接池,供 Celery 任务共享使用 +# 连接 CELERY_BACKEND DB,与 write_message:last_done 时间戳写入保持一致 +# 使用连接池而非单例客户端,提供更好的并发性能和自动重连 +_sync_redis_pool: redis.ConnectionPool = None + +def _get_or_create_redis_pool() -> redis.ConnectionPool: + """获取或创建 Redis 连接池(懒初始化)""" + global _sync_redis_pool + if _sync_redis_pool is None: + try: + _sync_redis_pool = redis.ConnectionPool( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB_CELERY_BACKEND, + password=settings.REDIS_PASSWORD, + decode_responses=True, + max_connections=10, + socket_connect_timeout=5, + socket_timeout=5, + retry_on_timeout=True, + health_check_interval=30, + ) + logger.info("Redis connection pool created for Celery tasks") + except Exception as e: + logger.error(f"Failed to create Redis connection pool: {e}", exc_info=True) + return None + return _sync_redis_pool + +def get_sync_redis_client() -> Optional[redis.StrictRedis]: + """获取同步 Redis 客户端(使用连接池) + + 使用连接池提供的客户端,支持自动重连和健康检查。 + 如果 Redis 不可用,返回 None,调用方应优雅降级。 + + Returns: + redis.StrictRedis: Redis 客户端实例,如果连接失败则返回 None + """ + try: + pool = _get_or_create_redis_pool() + if pool is None: + return None + + client = redis.StrictRedis(connection_pool=pool) + # 验证连接可用性 + client.ping() + return client + except RedisError as e: + logger.error(f"Redis connection failed: {e}", exc_info=True) + return None + except Exception as e: + logger.error(f"Unexpected error getting Redis client: {e}", exc_info=True) + return None # Import a unified Celery instance from app.celery_app import celery_app @@ -1090,6 +1147,22 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s logger.info( f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 + try: + _r = get_sync_redis_client() + if _r is not None: + from datetime import timedelta as _td + from datetime import timezone as _tz + _CST = 
_tz(_td(hours=8)) + _now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat() + _r.set( + f"write_message:last_done:{end_user_id}", + _now_cst, + ex=86400 * 30, + ) + except Exception as _e: + logger.warning(f"[CELERY WRITE] 写入 last_done 时间戳失败(不影响主流程): {_e}") + return { "status": "SUCCESS", "result": result, @@ -2149,12 +2222,15 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: start_time = time.time() async def _run() -> Dict[str, Any]: + from sqlalchemy import func, select + from app.core.logging_config import get_logger - from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage - from sqlalchemy import select, func - from app.services.implicit_memory_service import ImplicitMemoryService + from app.repositories.implicit_emotions_storage_repository import ( + ImplicitEmotionsStorageRepository, + ) from app.services.emotion_analytics_service import EmotionAnalyticsService + from app.services.implicit_memory_service import ImplicitMemoryService logger = get_logger(__name__) logger.info("开始执行隐性记忆和情绪数据更新定时任务") @@ -2167,18 +2243,20 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: with get_db_context() as db: try: - # 获取所有已存储数据的用户ID(分批次处理) repo = ImplicitEmotionsStorageRepository(db) - + # 先统计总数用于日志 from sqlalchemy import func total_users = db.execute( select(func.count()).select_from(ImplicitEmotionsStorage) ).scalar() or 0 - logger.info(f"找到 {total_users} 个需要更新的用户") + logger.info(f"表中存量用户总数: {total_users},开始时间轴筛选") - # 遍历每个用户并更新数据(分批次,避免一次性加载所有ID) - for end_user_id in repo.get_all_user_ids(batch_size=100): + # 构建 Redis 同步客户端,用于时间轴筛选 + _redis_client = get_sync_redis_client() + + # 只处理 last_done > updated_at 的用户(有新记忆写入的用户) + for end_user_id in repo.get_users_needing_refresh(_redis_client, batch_size=100): logger.info(f"开始处理用户: {end_user_id}") user_start_time = time.time() @@ -2264,10 +2342,10 @@ def 
update_implicit_emotions_storage(self) -> Dict[str, Any]: user_results.append(error_info) logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}") - # ---- 处理增量用户(当天新增、尚未初始化的用户)---- + # ---- 当天新增用户兜底初始化 ---- new_users_initialized = 0 new_users_failed = 0 - logger.info("开始处理当天新增的增量用户初始化") + logger.info("开始处理当天新增用户的兜底初始化") for end_user_id in repo.get_new_user_ids_today(batch_size=100): logger.info(f"开始初始化新用户: {end_user_id}") @@ -2281,35 +2359,27 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) await implicit_service.save_profile_cache( - end_user_id=end_user_id, - profile_data=profile_data, - db=db + end_user_id=end_user_id, profile_data=profile_data, db=db ) implicit_success = True logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像") except Exception as e: - error_msg = f"隐性记忆初始化失败: {str(e)}" - errors.append(error_msg) - logger.error(f"新用户 {end_user_id} {error_msg}") + errors.append(f"隐性记忆初始化失败: {str(e)}") + logger.error(f"新用户 {end_user_id} 隐性记忆初始化失败: {e}") try: emotion_service = EmotionAnalyticsService() suggestions_data = await emotion_service.generate_emotion_suggestions( - end_user_id=end_user_id, - db=db, - language="zh" + end_user_id=end_user_id, db=db, language="zh" ) await emotion_service.save_suggestions_cache( - end_user_id=end_user_id, - suggestions_data=suggestions_data, - db=db + end_user_id=end_user_id, suggestions_data=suggestions_data, db=db ) emotion_success = True logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议") except Exception as e: - error_msg = f"情绪建议初始化失败: {str(e)}" - errors.append(error_msg) - logger.error(f"新用户 {end_user_id} {error_msg}") + errors.append(f"情绪建议初始化失败: {str(e)}") + logger.error(f"新用户 {end_user_id} 情绪建议初始化失败: {e}") if implicit_success or emotion_success: new_users_initialized += 1 @@ -2319,7 +2389,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: 
user_elapsed = time.time() - user_start_time user_results.append({ "end_user_id": end_user_id, - "type": "init", + "type": "new_user_init", "implicit_success": implicit_success, "emotion_success": emotion_success, "errors": errors, @@ -2331,7 +2401,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: user_elapsed = time.time() - user_start_time user_results.append({ "end_user_id": end_user_id, - "type": "init", + "type": "new_user_init", "implicit_success": False, "emotion_success": False, "errors": [str(e)], @@ -2339,27 +2409,24 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: }) logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}") - logger.info( - f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}" - ) - # ---- 增量用户处理结束 ---- + logger.info(f"当天新增用户兜底初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}") + # ---- 新增用户兜底初始化结束 ---- - # 记录总体统计信息 logger.info( f"隐性记忆和情绪数据更新定时任务完成: " f"存量用户总数={total_users}, " f"隐性记忆成功={successful_implicit}, " f"情绪建议成功={successful_emotion}, " f"存量失败={failed}, " - f"增量初始化成功={new_users_initialized}, " - f"增量初始化失败={new_users_failed}" + f"新增用户初始化成功={new_users_initialized}, " + f"新增用户初始化失败={new_users_failed}" ) return { "status": "SUCCESS", "message": ( f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;" - f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败" + f"当天新增用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败" ), "total_users": total_users, "successful_implicit": successful_implicit, @@ -2367,7 +2434,7 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: "failed": failed, "new_users_initialized": new_users_initialized, "new_users_failed": new_users_failed, - "user_results": user_results[:50] # 只保留前50个用户的详细结果 + "user_results": user_results[:50] } except Exception as e: @@ -2416,3 +2483,125 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: "elapsed_time": elapsed_time, "task_id": self.request.id } + + +# 
============================================================================= + +@celery_app.task( + name="app.tasks.init_implicit_emotions_for_users", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, + # 触发型任务标识,区别于 periodic_tasks 队列中的定时任务 + triggered=True, +) +def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。 + + 由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。 + 存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。 + + Args: + end_user_ids: 需要检查的用户ID列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.implicit_emotions_storage_repository import ( + ImplicitEmotionsStorageRepository, + ) + from app.services.emotion_analytics_service import EmotionAnalyticsService + from app.services.implicit_memory_service import ImplicitMemoryService + + logger = get_logger(__name__) + logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}") + + initialized = 0 + failed = 0 + skipped = 0 + + with get_db_context() as db: + repo = ImplicitEmotionsStorageRepository(db) + + for end_user_id in end_user_ids: + existing = repo.get_by_end_user_id(end_user_id) + if existing is not None: + skipped += 1 + continue + + logger.info(f"用户 {end_user_id} 无记录,开始初始化") + implicit_ok = False + emotion_ok = False + try: + try: + implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) + profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) + await implicit_service.save_profile_cache( + end_user_id=end_user_id, profile_data=profile_data, db=db + ) + implicit_ok = True + except Exception as e: + logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}") + + try: + emotion_service = EmotionAnalyticsService() + suggestions_data = await emotion_service.generate_emotion_suggestions( + end_user_id=end_user_id, db=db, 
language="zh" + ) + await emotion_service.save_suggestions_cache( + end_user_id=end_user_id, suggestions_data=suggestions_data, db=db + ) + emotion_ok = True + except Exception as e: + logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}") + + if implicit_ok or emotion_ok: + initialized += 1 + else: + failed += 1 + except Exception as e: + failed += 1 + logger.error(f"用户 {end_user_id} 初始化异常: {e}") + + logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}") + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } diff --git a/api/app/version_info.json b/api/app/version_info.json index 7d82eabc..bbaffc17 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,36 @@ { + "v0.2.6": { + "introduction": { + "codeName": "听剑", + "releaseDate": "2026-3-6", + "upgradePosition": "🐻 多模态交互全面升级,记忆剪枝与工作流迁移双线并进,锋芒初露,兼收并蓄", + "coreUpgrades": [ + "1. 工作流与应用框架
* 工作流导入适配(Dify):支持 Dify 工作流定义无缝迁移
* 字段字数限制与校验规则:可配置字符限制与产品级校验
* 应用复制(Agent、工作流、集群):一键复制完整应用配置
* 对话变量(调试+分享):支持有状态多轮交互
* Chat 接口流式输出 message_id:流式响应包含消息追踪标识", + "2. 多模态与交互 💬
* 音频输入与输出:应用支持音频模态
* 文件类型输入支持:扩展支持语音、文件、视频上传", + "3. 模型与智能 🧠
* 模型视觉与 Omni 区分:精确区分视觉与 Omni 模型能力
* 教育记忆与陪伴玩具场景预设:垂直领域本体配置开箱即用
* 本体配置默认标识:支持基线配置标记
* 记忆配置默认标识:自动应用默认记忆设置", + "4. 记忆智能 🔬
* 记忆剪枝模块:智能裁剪冗余低价值记忆
* RAG 快速检索集成记忆:深度思考与正常回复双模式检索", + "5. 稳健性与缺陷修复 🔧
* 模型管理:修复自定义模型 API Key 批量配置错误
* 知识库管理:修复非源文档下载原始内容接口错误,更新分享停用提示文案
* 用户记忆:优化档案提取准确性(姓名、职业、兴趣分布)
* 长期记忆:修复情景记忆卡片重复和用户归属错误
* 工作空间首页:修复知识库数量、应用数量、总记忆容量、API 调用次数、知识库类型分布等数据不一致问题
* 基础设施:修正 Celery 环境变量配置,修复数据库连接池 idle-in-transaction 泄漏", + "
", + "v0.2.6 标志着 MemoryBear 在多模态交互、跨平台工作流迁移和智能记忆管理方面的重要突破。下一版本将聚焦 A2A 协议支持实现多智能体协作、多模态记忆能力扩展至语音与视觉领域,以及应用导入导出功能支持跨环境便携部署。", + "MemoryBear,让记忆有熊力 🐻✨" + ] + }, + "introduction_en": { + "codeName": "TingJian", + "releaseDate": "2026-3-6", + "upgradePosition": "🐻 Full multimodal interaction upgrade with memory pruning and workflow migration — sharpened edge, broader reach", + "coreUpgrades": [ + "1. Workflow & Application Framework
* Workflow Import Adaptation (Dify): Seamless Dify workflow migration
* Field Character Limits & Validation: Configurable limits with product-defined rules
* Application Cloning (Agent, Workflow, Cluster): One-click full config duplication
* Conversation Variables (Debug + Share): Stateful multi-turn interactions
* Streaming message_id in Chat API: Message tracking in streaming responses", + "2. Multimodal & Interaction 💬
* Audio Input & Output: Audio modality support for applications
* File Type Input Support: Voice, file, and video upload support", + "3. Model & Intelligence 🧠
* Model Vision & Omni Differentiation: Precise capability routing
* Education Memory & Companion Toy Presets: Domain-specific ontology configs
* Ontology Default Identifier: Baseline configuration flagging
* Memory Configuration Default Identifier: Auto-apply default settings", + "4. Memory Intelligence 🔬
* Memory Pruning Module: Intelligent trimming of redundant memories
* RAG Quick Retrieval with Memory: Dual-mode retrieval for deep-thinking and normal replies", + "5. Robustness & Bug Fixes 🔧
* Model Management: Fixed custom model API key batch configuration error
* Knowledge Base: Fixed an error in the original-content download API for non-source documents; updated the prompt text shown when sharing is disabled
* User Memory: Improved profile extraction accuracy (name, occupation, interests)
* Long-Term Memory: Fixed duplicate episodic memory cards and wrong user attribution
* Dashboard: Fixed data inconsistencies in knowledge base count, app count, total memory capacity, API call count, and knowledge base type distribution
* Infrastructure: Corrected Celery environment variables, fixed database connection pool idle-in-transaction leak", + "
", + "v0.2.6 marks a significant milestone for MemoryBear in multimodal interaction, cross-platform workflow migration, and intelligent memory management. The next release will focus on A2A protocol support for multi-agent collaboration, multimodal memory extending extraction to voice and visual domains, and application import/export for portable cross-environment deployment.", + "MemoryBear, Memory with Bear Power 🐻✨" + ] + } + }, "v0.2.5": { "introduction": { "codeName": "行云", diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 69763de2..5d358f2c 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -49,7 +49,7 @@ services: networks: - celery - # Periodic worker - Scheduled/beat tasks (prefork, low concurrency) + # Periodic worker - Scheduled/beat tasks + API-triggered tasks (prefork, low concurrency) worker-periodic: image: redbear-mem-open:latest container_name: worker-periodic diff --git a/api/migrations/versions/1ac07dc7366f_202603061644.py b/api/migrations/versions/1ac07dc7366f_202603061644.py new file mode 100644 index 00000000..81266d78 --- /dev/null +++ b/api/migrations/versions/1ac07dc7366f_202603061644.py @@ -0,0 +1,36 @@ +"""202603061644 + +Revision ID: 1ac07dc7366f +Revises: 6a4641cf192b +Create Date: 2026-03-06 16:51:10.152305 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1ac07dc7366f' +down_revision: Union[str, None] = '6a4641cf192b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('mcp_tool_configs', sa.Column('source_channel', sa.String(length=50), server_default=sa.text("'self_hosted'"), nullable=False, comment='来源渠道')) + op.add_column('mcp_tool_configs', sa.Column('market_id', sa.UUID(), nullable=True, comment='渠道市场id')) + op.add_column('mcp_tool_configs', sa.Column('market_config_id', sa.UUID(), nullable=True, comment='渠道市场配置id')) + op.add_column('mcp_tool_configs', sa.Column('mcp_service_id', sa.String(length=255), nullable=True, comment='mcp服务id')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('mcp_tool_configs', 'mcp_service_id') + op.drop_column('mcp_tool_configs', 'market_config_id') + op.drop_column('mcp_tool_configs', 'market_id') + op.drop_column('mcp_tool_configs', 'source_channel') + # ### end Alembic commands ### diff --git a/web/src/components/Chat/ChatInput.tsx b/web/src/components/Chat/ChatInput.tsx index 49fb65d2..508b0d0c 100644 --- a/web/src/components/Chat/ChatInput.tsx +++ b/web/src/components/Chat/ChatInput.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:14 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-04 18:42:49 + * @Last Modified time: 2026-03-06 13:36:20 */ import { type FC, useEffect, useMemo } from 'react' import { Flex, Input, Form } from 'antd' @@ -50,13 +50,17 @@ const ChatInput: FC = ({ const handleDelete = (file: any) => { - fileChange?.(fileList?.filter(item => item.uid !== file.uid) || []) + fileChange?.(fileList?.filter(item => { + return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl + : item.url && file.url ? item.url !== file.url + : item.uid !== file.uid + }) || []) } // Convert file object to preview URL const previewFileList = useMemo(() => { return fileList?.map(file => ({ ...file, - url: file.url || (file.originFileObj ? 
URL.createObjectURL(file.originFileObj) : file.thumbUrl) + url: file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined) })) || [] }, [fileList]) @@ -72,7 +76,7 @@ const ChatInput: FC = ({ {previewFileList.map((file) => { if (file.type.includes('image')) { return ( -
+
{file.name}
= ({ } if (file.type.includes('video')) { return ( -
+
} - extra={ handleChange(item)} />} + extra={ handleChange(item)} />} bodyClassName="rb:relative rb:pb-[64px]! rb:h-[calc(100%-64px)]!" > @@ -153,7 +153,7 @@ const ModelListDetail = forwardRef(({
- + {!item.model_id && } diff --git a/web/src/views/ModelManagement/types.ts b/web/src/views/ModelManagement/types.ts index 3233353b..d68e5521 100644 --- a/web/src/views/ModelManagement/types.ts +++ b/web/src/views/ModelManagement/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:50:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-04 11:39:20 + * @Last Modified time: 2026-03-06 12:26:11 */ /** * Type definitions for Model Management @@ -121,6 +121,7 @@ export interface ModelApiKey { * Model list item data structure */ export interface ModelListItem { + model_id?: string; /** Model name */ model_name?: string; /** Associated model config IDs */ diff --git a/web/src/views/Ontology/pages/Detail.tsx b/web/src/views/Ontology/pages/Detail.tsx index 22e08244..25609083 100644 --- a/web/src/views/Ontology/pages/Detail.tsx +++ b/web/src/views/Ontology/pages/Detail.tsx @@ -102,7 +102,7 @@ const Detail: FC = () => { {data.scene_name} - {t('common.default')} + {data.is_system_default ? {t('common.default')} : undefined} } subTitle={
{data.scene_description}
} extra={data.is_system_default ? undefined : ( diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 4021a9ee..b263120a 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -25,6 +25,7 @@ const InitialValuePlugin: React.FC = ({ value, options const textContent = root.getTextContent(); if (textContent !== prevValueRef.current) { isUserInputRef.current = true; + prevValueRef.current = textContent; } }); }); @@ -33,7 +34,13 @@ const InitialValuePlugin: React.FC = ({ value, options }, [editor]); useEffect(() => { - if ((value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) && !isUserInputRef.current) { + if (value !== prevValueRef.current || enableLineNumbers !== prevEnableLineNumbersRef.current) { + // Skip reset if the change was triggered by user input (avoid cursor jump) + if (isUserInputRef.current && enableLineNumbers === prevEnableLineNumbersRef.current) { + prevValueRef.current = value; + isUserInputRef.current = false; + return; + } queueMicrotask(() => { editor.update(() => { const root = $getRoot(); diff --git a/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts index 4dca4854..779174ff 100644 --- a/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts +++ b/web/src/views/Workflow/components/Properties/hooks/useVariableList.ts @@ -35,7 +35,8 @@ const NODE_VARIABLES = { ], 'http-request': [ { label: 'body', dataType: 'string', field: 'body' }, - { label: 'status_code', dataType: 'number', field: 'status_code' } + { label: 'status_code', dataType: 'number', field: 'status_code' }, + { label: 'headers', dataType: 'object', field: 'headers' }, ], 'question-classifier': [{ label: 'class_name', dataType: 
'string', field: 'class_name' }], 'memory-read': [ @@ -390,11 +391,6 @@ export const useVariableList = ( addVariable(list, keys, `${pid}_item`, 'item', itemType, `${pid}.item`, pd); addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); } else if (pd.type === 'iteration' && !pd.config.input.defaultValue) { - let itemType = 'object'; - const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); - if (iv?.dataType.startsWith('array[')) { - itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1'); - } addVariable(list, keys, `${pid}_item`, 'item', 'string', `${pid}.item`, pd); addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); } diff --git a/web/src/views/Workflow/components/Properties/index.tsx b/web/src/views/Workflow/components/Properties/index.tsx index 76fc9ad0..bd5392cd 100644 --- a/web/src/views/Workflow/components/Properties/index.tsx +++ b/web/src/views/Workflow/components/Properties/index.tsx @@ -95,7 +95,7 @@ const Properties: FC = ({ initialValue[key] = config[key].defaultValue } }) - + form.setFieldsValue({ type, id: selectedNode.id, @@ -114,16 +114,16 @@ const Properties: FC = ({ */ const updateNodeLabel = (newLabel: string) => { if (selectedNode && form) { - const nodeData = selectedNode.data as NodeProperties; + const nodeData = selectedNode.getData() as NodeProperties; selectedNode.setAttrByPath('text/text', `${nodeData.icon} ${newLabel}`); - selectedNode.setData({ ...selectedNode.data, name: newLabel }); + selectedNode.setData({ ...selectedNode.getData(), name: newLabel }); } }; useEffect(() => { if (values && selectedNode) { const { id, knowledge_retrieval, group, group_variables, ...rest } = values - const { knowledge_bases = [], ...restKnowledgeConfig } = (knowledge_retrieval as any) || {} + const { knowledge_bases = [], name: _name, description: _description, ...restKnowledgeConfig } = (knowledge_retrieval as any) || {} let allRest = { ...rest, @@ -136,21 +136,23 @@ const 
Properties: FC = ({ })) } + const nodeData = selectedNode.getData() + Object.keys(values).forEach(key => { - if (selectedNode.data?.config?.[key]) { + if (nodeData?.config?.[key]) { // Create a deep copy to avoid reference sharing between nodes - if (!selectedNode.data.config[key]) { - selectedNode.data.config[key] = {}; + if (!nodeData.config[key]) { + nodeData.config[key] = {}; } - selectedNode.data.config[key] = { - ...selectedNode.data.config[key], + nodeData.config[key] = { + ...nodeData.config[key], defaultValue: values[key] }; } }) selectedNode?.setData({ - ...selectedNode.data, + ...nodeData, ...allRest, }) } diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index e7d2177a..0b2ec5ce 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -529,6 +529,10 @@ export const unknownNode = { type: 'unknown', icon: unknownIcon } +export const noteNode = { + type: 'notes', + icon: unknownIcon +} export const nodeWidth = 240; /** diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 2d8d1939..971d591a 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-28 17:59:34 + * @Last Modified time: 2026-03-07 15:23:39 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -12,7 +12,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from ' import { register } from '@antv/x6-react-shape'; import type { PortMetadata } from '@antv/x6/lib/model/port'; -import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode } from '../constant'; +import { 
nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, noteNode } from '../constant'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types'; import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' @@ -128,7 +128,7 @@ export const useWorkflowGraph = ({ if (nodes.length) { const nodeList = nodes.map(node => { const { id, type, name, position, config = {} } = node - let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode] }] + let nodeLibraryConfig = [...nodeLibrary, { nodes: [unknownNode, noteNode] }] .flatMap(category => category.nodes) .find(n => n.type === type) nodeLibraryConfig = JSON.parse(JSON.stringify({ config: {}, ...nodeLibraryConfig })) as NodeProperties @@ -715,6 +715,8 @@ export const useWorkflowGraph = ({ panning: isHandMode, mousewheel: { enabled: true, + factor: 0.1, + modifiers: null, }, connecting: { connector: {