diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 2069dd66..c871c511 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -117,7 +117,7 @@ async def get_prompt_opt( user_require=data.message ): # chunk 是 prompt 的增量内容 - yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n" + yield f"event:message\ndata: {json.dumps(chunk)}\n\n" return StreamingResponse( event_generator(), diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8eb31fb4..a1ec2e1d 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -29,7 +29,7 @@ class WorkflowState(TypedDict): # Set of loop node IDs, used for assigning values in loop nodes cycle_nodes: list - looping: bool + looping: Annotated[bool, lambda x, y: x and y] # Input variables (passed from configured variables) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 55919998..4374d847 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode): retries -= 1 if retries > 0: await asyncio.sleep(self.typed_config.retry.retry_interval / 1000) + elif self.typed_config.error_handle.method == HttpErrorHandle.NONE: + raise e + except Exception as e: + raise RuntimeError(f"HTTP request node exception: {e}") else: match self.typed_config.error_handle.method: - case HttpErrorHandle.NONE: - logger.warning( - f"Node {self.node_id}: HTTP request failed, returning error response" - ) - return HttpRequestNodeOutput( - body="", - status_code=resp.status_code, - headers=resp.headers, - ).model_dump() case HttpErrorHandle.DEFAULT: logger.warning( f"Node {self.node_id}: HTTP request failed, returning default result" @@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode): f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) return "ERROR" + raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index e12c6224..d9caae7e 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode): rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, indices=indices, score_threshold=kb_config.similarity_threshold) - # Deduplicate hybrid retrieval results + # Deduplicate hy brid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) vector_service.reranker = self.get_reranker_model() rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) case _: raise RuntimeError("Unknown retrieval type") vector_service.reranker = self.get_reranker_model() + # TODO:其他重排序方式支持 final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) logger.info( f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" ) - return [chunk.model_dump() for chunk in final_rs] + return [chunk.page_content for chunk in final_rs] diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index da94482b..8498fc38 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -1,5 +1,7 @@ """LLM 节点配置""" +from typing import Any + from pydantic import BaseModel, Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType @@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti class MessageConfig(BaseModel): """消息配置""" - + role: str = Field( ..., description="消息角色:system, user, assistant" ) - + content: str = Field( ..., description="消息内容,支持模板变量,如:{{ sys.message }}" ) - + @field_validator("role") @classmethod def validate_role(cls, v: str) -> str: @@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig): 1. 简单模式:使用 prompt 字段 2. 消息模式:使用 messages 字段(推荐) """ - + model_id: str = Field( ..., description="模型配置 ID" ) - + + context: Any = Field( + default="", + description="上下文" + ) + # 简单模式 prompt: str | None = Field( default=None, description="提示词模板(简单模式),支持变量引用" ) - + # 消息模式(推荐) messages: list[MessageConfig] | None = Field( default=None, description="消息列表(消息模式),支持多轮对话" ) - + # 模型参数 temperature: float | None = Field( default=0.7, @@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig): le=2.0, description="温度参数,控制输出的随机性" ) - + max_tokens: int | None = Field( default=1000, ge=1, le=32000, description="最大生成 token 数" ) - + top_p: float | None = Field( default=None, ge=0.0, le=1.0, description="Top-p 采样参数" ) - + frequency_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="频率惩罚" ) - + presence_penalty: float | None = Field( default=None, ge=-2.0, le=2.0, description="存在惩罚" ) - + # 输出变量定义 output_variables: list[VariableDefinition] = Field( default_factory=lambda: [ @@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig): ], description="输出变量定义(自动生成,通常不需要修改)" ) - + @field_validator("messages", "prompt") @classmethod def validate_input_mode(cls, v, info): """验证输入模式:prompt 和 messages 至少有一个""" # 这个验证在 model_validator 中更合适 return v - + class Config: json_schema_extra = { "examples": [ diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 65826d84..334229f7 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -5,15 +5,17 @@ LLM 节点实现 """ import logging +import re from typing import Any from langchain_core.messages import AIMessage, SystemMessage, HumanMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.db import get_db_context from app.models import ModelType from app.services.model_service import ModelConfigService - + from app.core.exceptions import BusinessException from app.core.error_codes import BizCode @@ -63,8 +65,15 @@ class LLMNode(BaseNode): - user/human: 用户消息(HumanMessage) - ai/assistant: AI 消息(AIMessage) """ - - def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]: + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = LLMNodeConfig(**self.config) + + def _render_context(self, message,state): + context = f"{self._render_template(self.typed_config.context, state)}" + return re.sub(r"{{context}}", context, message) + + def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]: """准备 LLM 实例(公共逻辑) Args: @@ -76,15 +85,16 @@ class LLMNode(BaseNode): # 1. 处理消息格式(优先使用 messages) messages_config = self.config.get("messages") - + if messages_config: # 使用 LangChain 消息格式 messages = [] for msg_config in messages_config: role = msg_config.get("role", "user").lower() content_template = msg_config.get("content", "") + content_template = self._render_context(content_template, state) content = self._render_template(content_template, state) - + # 根据角色创建对应的消息对象 if role == "system": messages.append(SystemMessage(content=content)) @@ -95,7 +105,7 @@ class LLMNode(BaseNode): else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append(HumanMessage(content=content)) - + prompt_or_messages = messages else: # 使用简单的 prompt 格式(向后兼容) @@ -106,17 +116,17 @@ class LLMNode(BaseNode): model_id = self.config.get("model_id") if not model_id: raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") - + # 3. 在 with 块内完成所有数据库操作和数据提取 with get_db_context() as db: config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) - - if not config: + + if not config: raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) - + if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - + # 在 Session 关闭前提取所有需要的数据 api_config = config.api_keys[0] model_name = api_config.model_name @@ -124,26 +134,26 @@ class LLMNode(BaseNode): api_key = api_config.api_key api_base = api_config.api_base model_type = config.type - + # 4. 创建 LLM 实例(使用已提取的数据) # 注意:对于流式输出,需要在模型初始化时设置 streaming=True extra_params = {"streaming": stream} if stream else {} - + llm = RedBearLLM( RedBearModelConfig( model_name=model_name, - provider=provider, + provider=provider, api_key=api_key, base_url=api_base, extra_params=extra_params - ), + ), type=ModelType(model_type) ) - + logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") - + return llm, prompt_or_messages - + async def execute(self, state: WorkflowState) -> AIMessage: """非流式执行 LLM 调用 @@ -153,10 +163,10 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ - llm, prompt_or_messages = self._prepare_llm(state,True) - + llm, prompt_or_messages = self._prepare_llm(state, True) + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") - + # 调用 LLM(支持字符串或消息列表) response = await llm.ainvoke(prompt_or_messages) # 提取内容 @@ -164,16 +174,16 @@ class LLMNode(BaseNode): content = response.content else: content = str(response) - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") - + # 返回 AIMessage(包含响应元数据) return response if isinstance(response, AIMessage) else AIMessage(content=content) - + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录)""" _, prompt_or_messages = self._prepare_llm(state) - + return { "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "messages": [ @@ -186,13 +196,13 @@ class LLMNode(BaseNode): "max_tokens": self.config.get("max_tokens") } } - + def _extract_output(self, business_result: Any) -> str: """从 AIMessage 中提取文本内容""" if isinstance(business_result, AIMessage): return business_result.content return str(business_result) - + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """从 AIMessage 中提取 token 使用情况""" if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): @@ -204,7 +214,7 @@ class LLMNode(BaseNode): "total_tokens": usage.get('total_tokens', 0) } return None - + async def execute_stream(self, state: WorkflowState): """流式执行 LLM 调用 @@ -215,26 +225,26 @@ class LLMNode(BaseNode): 文本片段(chunk)或完成标记 """ from langgraph.config import get_stream_writer - + llm, prompt_or_messages = self._prepare_llm(state, True) - + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") - + # 检查是否有注入的 End 节点前缀配置 writer = get_stream_writer() end_prefix = getattr(self, '_end_node_prefix', None) - + logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}") if end_prefix: logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'") - + if end_prefix: # 渲染前缀(可能包含其他变量) try: rendered_prefix = self._render_template(end_prefix, state) logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'") - + # 提前发送 End 节点的前缀(使用 "message" 类型) writer({ "type": "message", # End 相关的内容都是 message 类型 @@ -246,12 +256,12 @@ class LLMNode(BaseNode): }) except Exception as e: logger.warning(f"渲染/发送 End 节点前缀失败: {e}") - + # 累积完整响应 full_response = "" last_chunk = None chunk_count = 0 - + # 调用 LLM(流式,支持字符串或消息列表) async for chunk in llm.astream(prompt_or_messages): # 提取内容 @@ -259,18 +269,18 @@ class LLMNode(BaseNode): content = chunk.content else: content = str(chunk) - + # 只有当内容不为空时才处理 if content: full_response += content last_chunk = chunk chunk_count += 1 - + # 流式返回每个文本片段 yield content - + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - + # 构建完整的 AIMessage(包含元数据) if isinstance(last_chunk, AIMessage): final_message = AIMessage( @@ -279,6 +289,6 @@ class LLMNode(BaseNode): ) else: final_message = AIMessage(content=full_response) - + # yield 完成标记 yield {"__final__": True, "result": final_message} diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 09c9fc68..bb2366f6 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode): return await MemoryAgentService().read_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, search_switch=self.typed_config.search_switch, history=[], @@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode): return await MemoryAgentService().write_memory( group_id=end_user_id, - message=self.typed_config.message, + message=self._render_template(self.typed_config.message, state), config_id=self.typed_config.config_id, db=db, storage_type="neo4j", diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 67f53801..b0f2c28d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" @@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode): f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" ) # 若分类列表为空,返回默认unknown分支,否则返回CASE1 - return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" + if category_count > 0: + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } try: llm = self._get_llm_instance() @@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - return f"CASE{category_names.index(category) + 1}" + return { + "class_name": category, + "output": f"CASE{category_names.index(category) + 1}", + } except Exception as e: logger.error( f"节点 {self.node_id} 分类执行异常:{str(e)}", @@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode): ) # 异常时返回默认分支,保证工作流容错性 if category_count > 0: - return DEFAULT_EMPTY_QUESTION_CASE - return "unknown" + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } diff --git a/api/app/core/workflow/nodes/tool/config.py b/api/app/core/workflow/nodes/tool/config.py index 487efae2..d3b1a644 100644 --- a/api/app/core/workflow/nodes/tool/config.py +++ b/api/app/core/workflow/nodes/tool/config.py @@ -1,4 +1,6 @@ from pydantic import Field +from typing import Any + from app.core.workflow.nodes.base_config import BaseNodeConfig @@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig): """工具节点配置""" tool_id: str = Field(..., description="工具ID") - tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") + tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 993a3804..e1b5f380 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -1,5 +1,5 @@ import logging -import uuid +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -9,6 +9,8 @@ from app.db import get_db_read logger = logging.getLogger(__name__) +TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}") + class ToolNode(BaseNode): """工具节点""" @@ -25,25 +27,33 @@ class ToolNode(BaseNode): # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: - workflow_id = self.get_variable("sys.workflow_id", state) - if workflow_id: + workspace_id = self.get_variable("sys.workspace_id", state) + if workspace_id: from app.repositories.tool_repository import ToolRepository with get_db_read() as db: - tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id) if not tenant_id: - tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097") - # logger.error(f"节点 {self.node_id} 缺少租户ID") - # return {"error": "缺少租户ID"} + logger.error(f"节点 {self.node_id} 缺少租户ID") + return { + "success": False, + "data": "缺少租户ID" + } # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - rendered_value = self._render_template(param_template, state) + if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, state) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + # 非模板参数(数字/布尔/普通字符串)直接保留原值 + rendered_value = param_template rendered_parameters[param_name] = rendered_value logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") - print(self.typed_config.tool_id) # 执行工具 with get_db_read() as db: @@ -54,7 +64,7 @@ class ToolNode(BaseNode): tenant_id=tenant_id, user_id=user_id ) - print(result) + if result.success: logger.info(f"节点 {self.node_id} 工具执行成功") return { @@ -66,7 +76,7 @@ class ToolNode(BaseNode): logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") return { "success": False, - "error": result.error, + "data": result.error, "error_code": result.error_code, "execution_time": result.execution_time } \ No newline at end of file diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 00358d91..6daf415d 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -87,10 +87,11 @@ class WorkflowValidator: return graphs @classmethod - def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: + def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]: """验证工作流配置 Args: + publish: 发布验证标识 workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 Returns: @@ -114,7 +115,7 @@ class WorkflowValidator: graphs = cls.get_subgraph(workflow_config) logger.info(graphs) - for graph in graphs: + for index, graph in enumerate(graphs): nodes = graph.get("nodes", []) edges = graph.get("edges", []) variables = graph.get("variables", []) @@ -125,10 +126,11 @@ class WorkflowValidator: elif len(start_nodes) > 1: errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") - # 2. 验证 end 节点(至少一个) - end_nodes = [n for n in nodes if n.get("type") == NodeType.END] - if len(end_nodes) == 0: - errors.append("工作流必须至少有一个 end 节点") + if index == len(graphs) - 1: + # 2. 验证 主图end 节点(至少一个) + end_nodes = [n for n in nodes if n.get("type") == NodeType.END] + if len(end_nodes) == 0: + errors.append("工作流必须至少有一个 end 节点") # 3. 验证节点 ID 唯一性 node_ids = [n.get("id") for n in nodes] @@ -159,15 +161,17 @@ class WorkflowValidator: elif target not in node_id_set: errors.append(f"边 #{i} 的 target 节点不存在: {target}") - # 6. 验证所有节点可达(从 start 节点出发) - if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( - start_nodes[0]["id"], - edges - ) - unreachable = node_id_set - reachable - if unreachable: - errors.append(f"以下节点无法从 start 节点到达: {unreachable}") + if publish: + # 仅在发布时验证所有节点可达 + # 6. 验证所有节点可达(从 start 节点出发) + if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 + reachable = WorkflowValidator._get_reachable_nodes( + start_nodes[0]["id"], + edges + ) + unreachable = node_id_set - reachable + if unreachable: + errors.append(f"以下节点无法从 start 节点到达: {unreachable}") # 7. 检测循环依赖(非 loop 节点) if not errors: # 只有在前面验证通过时才检查循环 @@ -288,7 +292,7 @@ class WorkflowValidator: (is_valid, errors): 是否有效和错误列表 """ # 先执行基础验证 - is_valid, errors = WorkflowValidator.validate(workflow_config) + is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True) if not is_valid: return False, errors diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 3aa7b16e..257910c3 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -38,6 +38,33 @@ class ToolRepository: return result[0] if result else None + @staticmethod + def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]: + """ + 根据空间ID获取tenant_id + + Args: + db: 数据库会话 + workspace_id: 空间ID + + Returns: + tenant_id或None + """ + from app.models.workspace_model import Workspace + + tenant_id = db.query(Workspace.tenant_id).filter( + Workspace.id == workspace_id + ).scalar() + + if tenant_id is not None and not isinstance(tenant_id, uuid.UUID): + # 兼容数据库中字段类型不匹配的情况(比如存储为字符串) + try: + tenant_id = uuid.UUID(tenant_id) + except (ValueError, TypeError): + return None + + return tenant_id + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 482e8213..b3ac1b79 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -231,9 +231,9 @@ class PromptOptimizerService: if m: prompt_index = m.start() prompt_finished = True - yield {"type": "delta", "content": buffer[idx:prompt_index]} + yield {"content": buffer[idx:prompt_index]} else: - yield {"type": "delta", "content": cache[idx:]} + yield {"content": cache[idx:]} if len(cache) != 0: idx = len(cache) @@ -249,8 +249,8 @@ class PromptOptimizerService: role=RoleType.ASSISTANT, content=desc ) - - yield {"type": "done", "desc": optim_result.get("desc")} + variables = self.parser_prompt_variables(optim_result.get("prompt")) + yield {"desc": optim_result.get("desc"), "variables": variables} @staticmethod def parser_prompt_variables(prompt: str): diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 50cca957..ab5128fd 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -344,14 +344,16 @@ class ToolService: break if operation_param: - # 有多个操作 + # 有多个操作,为每个操作生成具体参数 methods = [] for operation in operation_param.enum: + # 获取该操作的具体参数 + operation_params = self._get_operation_specific_params(tool_instance, operation) methods.append({ "method_id": f"{config.name}_{operation}", "name": operation, "description": f"{config.description} - {operation}", - "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + "parameters": operation_params }) return methods else: @@ -362,6 +364,243 @@ class ToolService: "description": config.description, "parameters": [p for p in tool_instance.parameters if p.name != "operation"] }] + + def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]: + """获取特定操作的参数列表""" + # 对于datetime_tool,根据操作类型返回相关参数 + if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool': + return self._get_datetime_tool_params(operation) + # 对于json_tool,根据操作类型返回相关参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': + return self._get_json_tool_params(operation) + + # 其他工具的默认处理:返回除operation外的所有参数 + return [{ + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } for param in tool_instance.parameters if param.name != "operation"] + + def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取datetime_tool特定操作的参数""" + if operation == "now": + return [ + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "format": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "convert_timezone": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + elif operation == "timestamp_to_datetime": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + else: + # 默认返回所有参数(除了operation) + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": False + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "calculation", + "type": "string", + "description": "时间计算表达式(如:+1d, -2h, +30m)", + "required": False + } + ] + + def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取json_tool特定操作的参数""" + base_params = [ + { + "name": "input_data", + "type": "string", + "description": "输入数据(JSON字符串、YAML字符串或XML字符串)", + "required": True + } + ] + + if operation == "insert": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "new_value", + "type": "string", + "description": "新值(用于insert操作)", + "required": True + } + ] + elif operation == "replace": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "old_text", + "type": "string", + "description": "要替换的原文本(用于replace操作)", + "required": True + }, + { + "name": "new_text", + "type": "string", + "description": "替换后的新文本(用于replace操作)", + "required": True + } + ] + elif operation == "delete": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + elif operation == "parse": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + + return base_params async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法""" diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index a6979b92..5f171a72 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -295,4 +295,14 @@ export const getKnowledgeGraph = async (kb_id: string) => { export const getKnowledgeGraphEntityTypes = async (query: any) => { const response = await request.get(`${apiPrefix}/knowledges/knowledge_graph_entity_types`,query); return response ; +}; +// 删除图谱 +export const deleteKnowledgeGraph = async (kb_id: string) => { + const response = await request.delete(`${apiPrefix}/knowledges/${kb_id}/knowledge_graph`); + return response; +}; +// 知识库图谱重建 +export const rebuildKnowledgeGraph = async (kb_id: string) => { + const response = await request.post(`${apiPrefix}/knowledges/${kb_id}/knowledge_graph`); + return response; }; \ No newline at end of file diff --git a/web/src/assets/images/workflow/memory-read.png b/web/src/assets/images/workflow/memory-read.png new file mode 100644 index 00000000..4b0cdc1d Binary files /dev/null and b/web/src/assets/images/workflow/memory-read.png differ diff --git a/web/src/assets/images/workflow/memory-write.png b/web/src/assets/images/workflow/memory-write.png new file mode 100644 index 00000000..83a50fd4 Binary files /dev/null and b/web/src/assets/images/workflow/memory-write.png differ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 0f3f5898..fc729d98 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -726,6 +726,11 @@ export const en = { graphTips: 'Explore the entity nodes in the knowledge base and their relationship networks', sourceDocuments: 'Source Documents', rebuildGraph: 'Rebuild Graph', + rebuildConfirmTitle: 'Confirm the rebuild graph', + rebuildConfirmContent: 'The rebuild graph will erase the existing map data and rebuild it from scratch. This operation is irreversible. Are you sure you want to proceed?', + deleteGraphSuccess: 'Knowledge graph deletion successful', + deleteGraphFailed:'Knowledge graph deletion failed', + graphEmpty: 'At the foot of the mountain of books, the journey begins.', createForm:{ name: 'Name', embedding_id: 'Embedding', @@ -1793,12 +1798,20 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re "not_contains": 'Does Not Contain', "startwith": 'Starts With', "endwith": 'Ends With', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": 'Equals', + "ne": 'Not Equals', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": 'Is', + "ne": 'Is Not', + }, else_desc: 'Used to define the logic that should be executed when the if condition is not met.' }, 'http-request': { @@ -1839,12 +1852,17 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re loop: { cycle_vars: 'Loop Variables', condition: 'Loop Termination Condition', + max_loop: 'Maximum Loop Count', }, assigner: { assignments: 'Variables', - cover: 'Overwrite', + cover: 'Override', assign: 'Set', - clear: 'Clear' + clear: 'Clear', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: 'Input Variable', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 4aa03990..b6972c1f 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -336,6 +336,11 @@ export const zh = { graphTitle: '知识图谱:实体、关系与属性的关联网络', graphTips: '探索知识库中的实体节点及其关系脉络', rebuildGraph: '重建图谱', + rebuildConfirmTitle: '确认重建图谱', + rebuildConfirmContent: '重建图谱将清除现有的图谱数据并重新构建,此操作不可逆。确定要继续吗?', + deleteGraphSuccess: '删除知识图谱成功', + deleteGraphFailed:'删除知识图谱失败', + graphEmpty: '书山有路,此处为始', createForm: { name: '名称', embedding_id: '嵌入模型', @@ -1893,12 +1898,20 @@ export const zh = { "not_contains": '不包含', "startwith": '开始是', "endwith": '结束是', - "eq": '==', - "ne": '!=', - "lt": '<', - "le": '<=', - "gt": '>', - "ge": '>=', + "eq": '是', + "ne": '不是', + num: { + "eq": '=', + "ne": '≠', + "lt": '<', + "le": '≤', + "gt": '>', + "ge": '≥', + }, + boolean: { + "eq": '是', + "ne": '不是', + }, else_desc: '用于定义当 if 条件不满足时应执行的逻辑。' }, 'http-request': { @@ -1939,12 +1952,17 @@ export const zh = { loop: { cycle_vars: '循环变量', condition: '循环终止条件', + max_loop: '最大循环次数', }, assigner: { assignments: '变量', cover: '覆盖', assign: '设置', - clear: '清空' + clear: '清空', + add: '+=', + subtract: '-=', + multiply: '*=', + divide: '/=', }, iteration: { input: '输入变量', diff --git a/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx b/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx index d6d5ee4f..8087e596 100644 --- a/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx +++ b/web/src/views/KnowledgeBase/[knowledgeBaseId]/Private.tsx @@ -657,6 +657,7 @@ const Private: FC = () => { const handleRefreshTable = () => { // 刷新表格数据 + fetchKnowledgeBaseDetail(knowledgeBase.id) tableRef.current?.loadData(); } return ( diff --git a/web/src/views/KnowledgeBase/components/CreateModal.tsx b/web/src/views/KnowledgeBase/components/CreateModal.tsx index 2349233f..ce228fa4 100644 --- a/web/src/views/KnowledgeBase/components/CreateModal.tsx +++ b/web/src/views/KnowledgeBase/components/CreateModal.tsx @@ -7,7 +7,9 @@ import { getModelList, createKnowledgeBase, updateKnowledgeBase, - getKnowledgeGraphEntityTypes + getKnowledgeGraphEntityTypes, + deleteKnowledgeGraph, + rebuildKnowledgeGraph } from '@/api/knowledgeBase' import RbModal from '@/components/RbModal' const { TextArea } = Input; @@ -31,6 +33,7 @@ const CreateModal = forwardRef(({ const [activeTab, setActiveTab] = useState('basic'); const [generatingEntityTypes, setGeneratingEntityTypes] = useState(false); const [isRebuildMode, setIsRebuildMode] = useState(false); + const [originalType, setOriginalType] = useState(''); // 保存原始的 type 参数 // 监听 parser_config.graphrag 相关字段的变化 const parserConfig = Form.useWatch('parser_config', form); @@ -47,6 +50,7 @@ const CreateModal = forwardRef(({ setLoading(false); setActiveTab('basic'); setIsRebuildMode(false); // 重置重建模式标识 + setOriginalType(''); // 重置原始 type setVisible(false); }; @@ -224,9 +228,12 @@ const CreateModal = forwardRef(({ const handleOpen = (record?: KnowledgeBaseListItem | null, type?: string) => { setDatasets(record || null); - const nextType = type || currentType; - setCurrentType(nextType as any); + + // 如果是重建模式,使用记录的实际类型,否则使用传入的类型 + const actualType = type === 'rebuild' ? (record?.type || 'General') : (type || currentType); + setCurrentType(actualType as any); setIsRebuildMode(type === 'rebuild'); // 设置重建模式标识 + setOriginalType(type || ''); // 保存原始的 type 参数 // 如果是重建模式,默认切换到知识图谱标签页 if (type === 'rebuild') { @@ -235,7 +242,7 @@ const CreateModal = forwardRef(({ setActiveTab('basic'); } - setBaseFields(record || null, nextType); + setBaseFields(record || null, actualType); getTypeList(record || null); setVisible(true); }; @@ -260,6 +267,39 @@ const CreateModal = forwardRef(({ // 封装保存方法,添加提交逻辑 const handleSave = () => { + // 获取当前表单中的知识图谱开启状态 + const currentFormValues = form.getFieldsValue(); + const isGraphragEnabled = currentFormValues?.parser_config?.graphrag?.use_graphrag || false; + + // 如果原始 type 是 'rebuild' 并且知识图谱开启为true,显示确认弹框 + if (originalType === 'rebuild' && isGraphragEnabled) { + confirm({ + title: t('knowledgeBase.rebuildConfirmTitle'), + content: t('knowledgeBase.rebuildConfirmContent'), + onOk: async() => { + handleDeleteGraph() + performSave(); + await rebuildKnowledgeGraph(datasets?.id || '') + }, + onCancel: () => { + // 用户取消,不执行任何操作 + }, + }); + } else { + // 非重建模式或知识图谱未开启,直接保存 + performSave(); + } + }; + const handleDeleteGraph = () => { + try{ + deleteKnowledgeGraph(datasets?.id || '') + console.log(t('knowledgeBase.deleteGraphSuccess')) + }catch(e){ + messageApi.error(t('knowledgeBase.deleteGraphFailed')) + } + }; + // 实际的保存逻辑 + const performSave = () => { form .validateFields() .then(() => { @@ -276,9 +316,12 @@ const CreateModal = forwardRef(({ formValues.parser_config.graphrag.entity_types = entityTypesArray; } + // 确保保存时使用正确的类型(不是 'rebuild') + const saveType = originalType === 'rebuild' ? currentType : (formValues.type || currentType); + const payload: KnowledgeBaseFormData = { ...formValues, - type: formValues.type || currentType, + type: saveType, permission_id: formValues.permission_id || 'Private', parent_id: datasets?.parent_id || undefined, }; diff --git a/web/src/views/KnowledgeBase/components/KnowledgeGraphCard.tsx b/web/src/views/KnowledgeBase/components/KnowledgeGraphCard.tsx index a485bacc..3dd7ab22 100644 --- a/web/src/views/KnowledgeBase/components/KnowledgeGraphCard.tsx +++ b/web/src/views/KnowledgeBase/components/KnowledgeGraphCard.tsx @@ -4,7 +4,7 @@ * @Author: yujiangping * @Date: 2025-12-30 15:07:37 * @LastEditors: yujiangping - * @LastEditTime: 2026-01-05 16:18:53 + * @LastEditTime: 2026-01-05 20:28:51 */ import React, { useState, useEffect } from 'react' import { useTranslation } from 'react-i18next'; @@ -38,7 +38,13 @@ const KnowledgeGraphCard: React.FC = ({ knowledgeBase, setLoading(true) try { const res = await getKnowledgeGraph(knowledgeBase?.id) - setData(res as KnowledgeGraphResponse) + // 判断 res.graph 是否为空对象或不存在 + const graphResponse = res as KnowledgeGraphResponse; + if (!graphResponse || !graphResponse.graph || Object.keys(graphResponse.graph).length === 0) { + setData(undefined) // 设置为 undefined 以显示 empty 状态 + } else { + setData(graphResponse) + } } catch (error) { console.error('获取知识图谱数据失败:', error) } finally { @@ -68,7 +74,10 @@ const KnowledgeGraphCard: React.FC = ({ knowledgeBase,
- {knowledgeBase?.parser_config?.graphrag?.use_graphrag ? () : } + {knowledgeBase?.parser_config?.graphrag?.use_graphrag ? + () + : + }
diff --git a/web/src/views/ToolManagement/Inner.tsx b/web/src/views/ToolManagement/Inner.tsx index d256d6c7..6f85e1f7 100644 --- a/web/src/views/ToolManagement/Inner.tsx +++ b/web/src/views/ToolManagement/Inner.tsx @@ -4,10 +4,9 @@ import { Col, Tag, List, - Space + Flex } from 'antd'; import { EyeOutlined } from '@ant-design/icons'; -import clsx from 'clsx' import { useTranslation } from 'react-i18next'; import dayjs, { type Dayjs } from 'dayjs' @@ -103,9 +102,9 @@ const Inner: React.FC<{ getStatusTag: (status: string) => ReactNode }> = ({ getS
{t(`tool.${item.config_data.tool_class}_features`)}
- + {InnerConfigData[item.config_data.tool_class].features.map(vo => { t(`tool.${vo}`) }) } - + {item.config_data.tool_class === 'DateTimeTool' ?
diff --git a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx index 571f1e4e..fabe45ba 100644 --- a/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx +++ b/web/src/views/Workflow/components/AddChatVariable/ChatVariableModal.tsx @@ -26,7 +26,6 @@ const ChatVariableModal = forwardRef(); const [loading, setLoading] = useState(false) const [editIndex, setEditIndex] = useState(undefined) - const typeValue = Form.useWatch('type', form); // 封装取消方法,添加关闭弹窗逻辑 const handleClose = () => { diff --git a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx index 963f824b..ed07392d 100644 --- a/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/CharacterCountPlugin.tsx @@ -14,18 +14,23 @@ const CharacterCountPlugin = ({ setCount, onChange }: { setCount: (count: number let serializedContent = ''; // Traverse all nodes and serialize properly + const paragraphs: string[] = []; root.getChildren().forEach(child => { if ($isParagraphNode(child)) { + let paragraphContent = ''; child.getChildren().forEach(node => { if ($isVariableNode(node)) { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } else { - serializedContent += node.getTextContent(); + paragraphContent += node.getTextContent(); } }); + paragraphs.push(paragraphContent); } }); + serializedContent = paragraphs.join('\n'); + setCount(serializedContent.length); onChange?.(serializedContent); }); diff --git a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx index 4059b300..93197150 100644 --- a/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/Workflow/components/Editor/plugin/InitialValuePlugin.tsx @@ -26,6 +26,7 @@ const InitialValuePlugin: React.FC = ({ value, options parts.forEach(part => { const match = part.match(/^\{\{([^.]+)\.([^}]+)\}\}$/); const contextMatch = part.match(/^\{\{context\}\}$/); + const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/); // 匹配{{context}}格式 if (contextMatch) { @@ -38,6 +39,20 @@ const InitialValuePlugin: React.FC = ({ value, options return } + // 匹配{{conv.xx}}格式 + if (conversationMatch) { + const [_, variableName] = conversationMatch; + const conversationSuggestion = options.find(s => + s.group === 'CONVERSATION' && s.label === variableName + ); + if (conversationSuggestion) { + paragraph.append($createVariableNode(conversationSuggestion)); + } else { + paragraph.append($createTextNode(part)); + } + return + } + // 匹配普通变量{{nodeId.label}}格式 if (match) { const [_, nodeId, label] = match; diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index a2f6d930..973a503c 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -13,13 +13,15 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { const handleNodeSelect = (selectedNodeType: any) => { const parentBBox = node.getBBox(); const cycleId = data.cycle; - + + const id = `${selectedNodeType.type.replace(/-/g, '_') }_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: parentBBox.x, y: parentBBox.y, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index b0b8d4ce..37feb2dc 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -75,12 +75,15 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { const parentBBox = node.getBBox(); const centerX = parentBBox.x + 24; // 默认节点宽度的一半 const centerY = parentBBox.y + 50; // 默认节点高度的一半 - + + const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const cycleStartNode = graph.addNode({ ...graphNodeLibrary.cycleStart, x: centerX, y: centerY, + id: cycleStartNodeId, data: { + id: cycleStartNodeId, type: 'cycle-start', parentId: node.id, isDefault: true, // 标记为默认节点,不可删除 diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 0be6fba1..9a644438 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -43,12 +43,14 @@ const PortClickHandler: React.FC = ({ graph }) => { const newY = sourceBBox.y; // 创建新节点 + const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` const newNode = graph.addNode({ ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default), x: newX, y: newY, + id, data: { - id: `${selectedNodeType.type}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`, + id, type: selectedNodeType.type, icon: selectedNodeType.icon, name: t(`workflow.${selectedNodeType.type}`), diff --git a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx index 34c133c7..eac3775f 100644 --- a/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx +++ b/web/src/views/Workflow/components/Properties/AssignmentList/index.tsx @@ -1,6 +1,6 @@ import { type FC } from 'react' import { useTranslation } from 'react-i18next'; -import { Form, Input, Button, Row, Col, Select } from 'antd' +import { Form, Input, Row, Col, Select, InputNumber, Radio } from 'antd' import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import VariableSelect from '../VariableSelect' @@ -11,6 +11,23 @@ interface AssignmentListProps { options: Suggestion[]; } +const operationsObj = { + number: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + { value: 'add', label: 'workflow.config.assigner.add' }, + { value: 'subtract', label: 'workflow.config.assigner.subtract' }, + { value: 'multiply', label: 'workflow.config.assigner.multiply' }, + { value: 'divide', label: 'workflow.config.assigner.divide' }, + ], + default: [ + { value: 'cover', label: 'workflow.config.assigner.cover' }, + { value: 'clear', label: 'workflow.config.assigner.clear' }, + { value: 'assign', label: 'workflow.config.assigner.assign' }, + ], +} + const AssignmentList: FC = ({ parentName, options = [], @@ -27,6 +44,11 @@ const AssignmentList: FC = ({ add({ operation: 'cover'})} />
{fields.map(({ key, name, ...restField }) => { + const variableSelector = form.getFieldValue([parentName, name, 'variable_selector']); + const selectedOption = options.find(option => `{{${option.value}}}` === variableSelector); + const dataType = selectedOption?.dataType; + const operationOptions = dataType === 'number' ? operationsObj.number : operationsObj.default; + return (
@@ -50,11 +72,10 @@ const AssignmentList: FC = ({ noStyle > ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} + placeholder={t('common.pleaseSelect')} /> @@ -280,11 +321,48 @@ const CaseList: FC = ({ - {!hideRightField && ( - - - - )} + {!hideRightField && <> + {leftFieldType === 'number' + ? + + + ({ - value: key, - label: t(`workflow.config.if-else.${key}`) + options={operatorList.map(vo => ({ + ...vo, + label: t(String(vo?.label || '')) }))} size="small" popupMatchSelectWidth={false} @@ -104,14 +139,53 @@ const ConditionList: FC = ({ onClick={() => remove(field.name)} /> - - {!hideRightField && ( - - - - - - )} + + {!hideRightField && <> + {leftFieldType === 'number' + ? + + + { - console.log('value record', value) - handleChange(record.key, 'type', value) - }} - /> - ), - }, - { - title: t('workflow.config.value'), + const columns = useMemo(() => { + const baseColumns = [ + { + title: typeOptions.length > 0 ? t('workflow.config.name') : '键', + dataIndex: 'name', + width: typeOptions.length > 0 ? '35%' : '45%', + render: (text: string, record: TableRow) => ( + handleChange(record.key, 'name', value || '')} + /> + ), + } + ]; + + if (typeOptions.length > 0) { + baseColumns.push({ + title: t('workflow.config.type'), + dataIndex: 'type', + width: '20%', + render: (text: string, record: TableRow) => ( + - - - - - remove(name)} /> - + + + + + + + + + + + + + remove(name)} /> + + ))}