diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 162c8e57..509ad442 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse from app.core.response_utils import success from app.services import knowledge_service, document_service, file_service, knowledgeshare_service from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.rag.common.settings import kg_retriever +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.embedding_model import OpenAIEmbed from app.core.logging_config import get_api_logger # Obtain a dedicated API logger @@ -389,36 +392,41 @@ async def retrieve_chunks( knowledge_model.Knowledge.chunk_num > 0, knowledge_model.Knowledge.status == 1 ] - existing_ids = knowledge_service.get_chunded_knowledgeids( + private_items = knowledge_service.get_chunded_knowledgeids( db=db, filters=filters, current_user=current_user ) + private_kb_ids = [item[0] for item in private_items] + private_workspace_ids = [item[1] for item in private_items] filters = [ knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids), knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, knowledge_model.Knowledge.chunk_num > 0, knowledge_model.Knowledge.status == 1 ] - share_ids = knowledge_service.get_chunded_knowledgeids( + items = knowledge_service.get_chunded_knowledgeids( db=db, filters=filters, current_user=current_user ) - if share_ids: + if items: filters = [ knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids) ] - items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id( + share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id( db=db, filters=filters, current_user=current_user ) - existing_ids.extend(items) - if not existing_ids: + share_kb_ids = [item[0] for item in share_items] + 
share_workspace_ids = [item[1] for item in share_items] + private_kb_ids.extend(share_kb_ids) + private_workspace_ids.extend(share_workspace_ids) + if not private_kb_ids: return success(data=[], msg="retrieval successful") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + kb_id = private_kb_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids] indices = ",".join(uuid_strs) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) if not db_knowledge: @@ -448,4 +456,21 @@ async def retrieve_chunks( seen_ids.add(doc.metadata["doc_id"]) unique_rs.append(doc) rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) + if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: + kb_ids = [str(kb_id) for kb_id in private_kb_ids] + workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] + # Prepare to configure chat_mdl、embedding_model、vision_model information + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base + ) + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base + ) + doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + rs.insert(0, doc) return success(data=rs, msg="retrieval successful") \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 index 200f2667..e649897a 100644 --- a/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/evaluate.jinja2 @@ -86,5 
+86,5 @@ - **quality_assessment**: quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) - **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null - (注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) + (注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容) 模式参考:{{ json_schema }} \ No newline at end of file diff --git a/api/app/core/models/rerank.py b/api/app/core/models/rerank.py index 64b3b566..c4b91e25 100644 --- a/api/app/core/models/rerank.py +++ b/api/app/core/models/rerank.py @@ -1,4 +1,3 @@ - from typing import Any, Dict, List, Optional, Sequence, Type, Union from copy import deepcopy from urllib.parse import urlparse @@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory from app.models import ModelProvider + class RedBearRerank(BaseDocumentCompressor): """ Rerank → 作为 Runnable 插入任意 LCEL 链""" + def __init__(self, config: RedBearModelConfig): self._model = self._create_model(config) self._config = config @@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor): return model_class(**model_params) def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """ Compress documents using Jina's Rerank API. 
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor): compressed.append(doc_copy) return compressed - def rerank( - self, - documents: Sequence[Union[str, Document, dict]], - query: str, - *, - top_n: Optional[int] = -1, - ) -> List[Dict[str, Any]]: - provider = self._config.provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + self, + documents: Sequence[Union[str, Document, dict]], + query: str, + *, + top_n: Optional[int] = -1, + ) -> List[Dict[str, Any]]: + provider = self._config.provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: import langchain_community.document_compressors.jina_rerank as jina_mod + # 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]: if not base_url: @@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor): # 设置完整的 rerank 端点,例如 http://host:port/v1/rerank jina_mod.JINA_API_URL = jina_base from langchain_community.document_compressors import JinaRerank - model_instance : JinaRerank = self._model - return model_instance.rerank(documents = documents, query = query, top_n=top_n) + model_instance: JinaRerank = self._model + return model_instance.rerank(documents=documents, query=query, top_n=top_n) else: raise ValueError(f"不支持的模型提供商: {provider}") - \ No newline at end of file diff --git a/api/app/core/rag/app/picture.py b/api/app/core/rag/app/picture.py index da133c27..99c3603d 100644 --- a/api/app/core/rag/app/picture.py +++ b/api/app/core/rag/app/picture.py @@ -51,7 +51,7 @@ def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs): img_binary = io.BytesIO() img.save(img_binary, format="JPEG") img_binary.seek(0) - ans = vision_model.describe(img_binary.read()) + ans, ans_num_tokens = vision_model.describe(img_binary.read()) callback(0.8, "CV LLM respond: %s ..." 
% ans[:32]) txt += "\n" + ans tokenize(doc, txt, eng) diff --git a/api/app/core/rag/deepdoc/parser/excel_parser.py b/api/app/core/rag/deepdoc/parser/excel_parser.py index b6e1e4a1..f7601ee3 100644 --- a/api/app/core/rag/deepdoc/parser/excel_parser.py +++ b/api/app/core/rag/deepdoc/parser/excel_parser.py @@ -42,11 +42,14 @@ class RAGExcelParser: file_like_object.seek(0) try: dfs = pd.read_excel(file_like_object, sheet_name=None) + if isinstance(dfs, dict): + dfs = next(iter(dfs.values())) return RAGExcelParser._dataframe_to_workbook(dfs) except Exception as ex: logging.info(f"pandas with default engine load error: {ex}, try calamine instead") file_like_object.seek(0) df = pd.read_excel(file_like_object, engine="calamine") + print(df) return RAGExcelParser._dataframe_to_workbook(df) except Exception as e_pandas: raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}") diff --git a/api/app/core/rag/graphrag/search.py b/api/app/core/rag/graphrag/search.py index 27eb674e..8823145f 100644 --- a/api/app/core/rag/graphrag/search.py +++ b/api/app/core/rag/graphrag/search.py @@ -4,6 +4,7 @@ from collections import defaultdict from copy import deepcopy import json_repair import pandas as pd +import time import trio from app.core.rag.common.misc_utils import get_uuid @@ -262,21 +263,21 @@ class KGSearch(Dealer): relas = "" return { - "chunk_id": get_uuid(), - "content_ltks": "", - "page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, - comm_topn, max_token), + "page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token), + "vector": None, + "metadata": { + "doc_id": get_uuid(), + "file_id": "", + "file_name": "Related content in Knowledge Graph", + "file_created_at": int(time.time() * 1000), "document_id": "", - "docnm_kwd": "Related content in Knowledge Graph", - "kb_id": kb_ids, - "important_kwd": [], - 
"image_id": "", - "similarity": 1., - "vector_similarity": 1., - "term_similarity": 0, - "vector": [], - "positions": [], - } + "knowledge_id": kb_ids, + "sort_id": 0, + "status": 1, + "score": 1 + }, + "children": None + } def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token): ## Community retrieval diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 57ff4dc8..1f696c98 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -26,6 +26,8 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr from app.core.rag.common.string_utils import remove_redundant_spaces from app.core.rag.common.float_utils import get_float from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD +from app.core.rag.llm.chat_model import Base +from app.core.rag.llm.embedding_model import OpenAIEmbed def knowledge_retrieval( @@ -48,6 +50,7 @@ def knowledge_retrieval( - merge_strategy: "weight" or other strategies - reranker_id: UUID of the reranker to use - reranker_top_k: int + - use_graph: bool, whether to use a graph Returns: Rearranged document block list (in descending order of relevance) @@ -59,6 +62,7 @@ def knowledge_retrieval( merge_strategy = config.get("merge_strategy", "weight") reranker_id = config.get("reranker_id") reranker_top_k = config.get("reranker_top_k", 1024) + use_graph = config.get("use_graph", "false").lower() == "true" file_names_filter = [] if user_ids: @@ -67,6 +71,10 @@ def knowledge_retrieval( if not knowledge_bases: return [] + kb_ids = [] + workspace_ids = [] + chat_model = None + embedding_model = None all_results = [] # Search each knowledge base for kb_config in knowledge_bases: @@ -87,6 +95,22 @@ def knowledge_retrieval( else: continue + if str(db_knowledge.id) not in kb_ids: + kb_ids.append(str(db_knowledge.id)) + if str(db_knowledge.workspace_id) not in workspace_ids: + workspace_ids.append(str(db_knowledge.workspace_id)) + if not 
chat_model: + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base + ) + if not embedding_model: + embedding_model = OpenAIEmbed( + key=db_knowledge.embedding.api_keys[0].api_key, + model_name=db_knowledge.embedding.api_keys[0].model_name, + base_url=db_knowledge.embedding.api_keys[0].api_base + ) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) # Retrieve according to the configured retrieval type match kb_config["retrieve_type"]: @@ -136,6 +160,12 @@ def knowledge_retrieval( # Use the specified reranker for re-ranking if reranker_id: return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) + # use graph + if use_graph: + from app.core.rag.common.settings import kg_retriever + doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model) + if doc: + all_results.insert(0, doc) return all_results except Exception as e: diff --git a/api/app/core/rag/utils/es_conn.py b/api/app/core/rag/utils/es_conn.py index 43b1dfbe..7fbf0e38 100644 --- a/api/app/core/rag/utils/es_conn.py +++ b/api/app/core/rag/utils/es_conn.py @@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection): m.topn * 2, query_vector=list(m.embedding_data), filter=bqry.to_dict(), - similarity=similarity, + # similarity=similarity ) if bqry and rank_feature: diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index de10f6f6..0d0879d7 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -4,9 +4,9 @@ 基于 LangGraph 的工作流执行引擎。 """ -import logging # import uuid import datetime +import logging from typing import Any from langchain_core.messages import HumanMessage @@ -107,7 +107,13 @@ class WorkflowExecutor: "user_id": self.user_id, "error": None, "error_node": None, - "streaming_buffer": {} # 
流式缓冲区 + "streaming_buffer": {}, # 流式缓冲区 + "cycle_nodes": [ + node.get("id") + for node in self.workflow_config.get("nodes") + if node.get("type") in [NodeType.LOOP, NodeType.ITERATION] + ], # loop, iteration node id + "looping": False # loop runing flag, only use in loop node,not use in main loop } def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: @@ -199,6 +205,10 @@ class WorkflowExecutor: for node in self.nodes: node_type = node.get("type") node_id = node.get("id") + cycle_node = node.get("cycle") + if cycle_node: + # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 + continue # 记录 start 和 end 节点 ID if node_type == NodeType.START: @@ -271,7 +281,7 @@ class WorkflowExecutor: workflow.add_edge(START, start_node_id) logger.debug(f"添加边: START -> {start_node_id}") - for edge in self.edges: + for edge in self.workflow_config.get("edges", []): source = edge.get("source") target = edge.get("target") edge_type = edge.get("type") @@ -284,12 +294,12 @@ class WorkflowExecutor: logger.debug(f"添加边: {source} -> {target}") continue - # 处理到 end 节点的边 - if target in end_node_ids: - # 连接到 end 节点 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - continue + # # 处理到 end 节点的边 + # if target in end_node_ids: + # # 连接到 end 节点 + # workflow.add_edge(source, target) + # logger.debug(f"添加边: {source} -> {target}") + # continue # 跳过错误边(在节点内部处理) if edge_type == "error": @@ -297,22 +307,30 @@ class WorkflowExecutor: if condition: # 条件边 - def router(state: WorkflowState, cond=condition, tgt=target): - """条件路由函数""" - if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } - ): - return tgt - return END # 条件不满足,结束 + def make_router(cond, tgt): + """Dynamically generate a conditional router function to ensure each branch has a unique name.""" - workflow.add_conditional_edges(source, router) + 
+ def router_fn(state: WorkflowState): + if evaluate_condition( + cond, + state.get("variables", {}), + state.get("node_outputs", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } + ): + return tgt + return END + + # 动态修改函数名,避免重复 + router_fn.__name__ = f"router_{tgt}" + return router_fn + + router_fn = make_router(condition, target) + workflow.add_conditional_edges(source, router_fn) logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") else: # 普通边 diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py index 81ab25dc..1a8b101e 100644 --- a/api/app/core/workflow/expression_evaluator.py +++ b/api/app/core/workflow/expression_evaluator.py @@ -74,6 +74,7 @@ class ExpressionEvaluator: # 为了向后兼容,也支持直接访问(但会在日志中警告) context.update(variables) context["nodes"] = node_outputs + context.update(node_outputs) try: # simpleeval 只支持安全的操作: diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index c174f52a..a637a8c1 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import AssignmentOperator -from app.core.workflow.nodes.operators import AssignmentOperatorInstance +from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -40,8 +40,8 @@ class AssignerNode(BaseNode): variable_selector = expression.split('.') # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv': # TODO: Loop node variable support 
(Feature) - raise ValueError("Only conversation variables can be assigned.") + if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]: + raise ValueError("Only conversation or cycle variables can be assigned.") # Get the value or expression to assign value = assignment.value @@ -55,7 +55,9 @@ class AssignerNode(BaseNode): ) # Select the appropriate assignment operator instance based on the target variable type - operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( + operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( + pool.get(variable_selector) + )( pool, variable_selector, value ) @@ -81,3 +83,5 @@ class AssignerNode(BaseNode): operator.remove_last() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") + logger.info(f"Node {self.node_id}: execution completed") + diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 90d02732..1550584a 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -14,9 +14,13 @@ class VariableType(StrEnum): STRING = "string" NUMBER = "number" BOOLEAN = "boolean" - ARRAY = "array" OBJECT = "object" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + class VariableDefinition(BaseModel): """变量定义 diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 82f3d9b8..e9fa4f25 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -20,40 +20,44 @@ logger = logging.getLogger(__name__) class WorkflowState(TypedDict): - """工作流状态 - - 在节点间传递的状态对象,包含消息、变量、节点输出等信息。 + """Workflow state + + The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc. 
""" - # 消息列表(追加模式) + # List of messages (append mode) messages: Annotated[list[AnyMessage], add] - - # 输入变量(从配置的 variables 传入) - # 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx) + + # Set of loop node IDs, used for assigning values in loop nodes + cycle_nodes: list + looping: bool + + # Input variables (passed from configured variables) + # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) variables: Annotated[dict[str, Any], lambda x, y: { **x, **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v for k, v in y.items()} }] - - # 节点输出(存储每个节点的执行结果,用于变量引用) - # 使用自定义合并函数,将新的节点输出合并到现有字典中 + + # Node outputs (stores execution results of each node for variable references) + # Uses a custom merge function to combine new node outputs into the existing dictionary node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - - # 运行时节点变量(简化版,只存储业务数据,供节点间快速访问) - # 格式:{node_id: business_result} + + # Runtime node variables (simplified version, stores business data for fast access between nodes) + # Format: {node_id: business_result} runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # 执行上下文 + # Execution context execution_id: str workspace_id: str user_id: str - # 错误信息(用于错误边) + # Error information (for error edges) error: str | None error_node: str | None - - # 流式缓冲区(存储节点的实时流式输出) - # 格式:{node_id: {"chunks": [...], "full_content": "..."}} + + # Streaming buffer (stores real-time streaming output of nodes) + # Format: {node_id: {"chunks": [...], "full_content": "..."}} streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] @@ -74,6 +78,7 @@ class BaseNode(ABC): self.workflow_config = workflow_config self.node_id = node_config["id"] self.node_type = node_config["type"] + self.cycle = node_config.get("cycle") self.node_name = node_config.get("name", self.node_id) # 使用 or 运算符处理 None 值 self.config = node_config.get("config") or {} @@ -170,10 +175,10 @@ class BaseNode(ABC): import time start_time = 
time.time() + + timeout = self.get_timeout() try: - timeout = self.get_timeout() - # 调用节点的业务逻辑 business_result = await asyncio.wait_for( self.execute(state), @@ -200,7 +205,8 @@ class BaseNode(ABC): **wrapped_output, "runtime_vars": { self.node_id: runtime_var - } + }, + "looping": state["looping"] } except TimeoutError: @@ -236,10 +242,10 @@ class BaseNode(ABC): import time start_time = time.time() + + timeout = self.get_timeout() try: - timeout = self.get_timeout() - # Get LangGraph's stream writer for sending custom data writer = get_stream_writer() diff --git a/api/app/core/workflow/nodes/breaker/__init__.py b/api/app/core/workflow/nodes/breaker/__init__.py new file mode 100644 index 00000000..d028cc25 --- /dev/null +++ b/api/app/core/workflow/nodes/breaker/__init__.py @@ -0,0 +1,3 @@ +from app.core.workflow.nodes.breaker.node import BreakNode + +__all__ = ["BreakNode"] diff --git a/api/app/core/workflow/nodes/breaker/node.py b/api/app/core/workflow/nodes/breaker/node.py new file mode 100644 index 00000000..882ffda0 --- /dev/null +++ b/api/app/core/workflow/nodes/breaker/node.py @@ -0,0 +1,33 @@ +import logging +from typing import Any + +from app.core.workflow.nodes import BaseNode, WorkflowState + +logger = logging.getLogger(__name__) + + +class BreakNode(BaseNode): + """ + Workflow node that immediately stops loop execution. + + When executed, this node sets the 'looping' flag in the workflow state + to False, signaling the outer loop runtime to terminate further iterations. + """ + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the break node. + + Args: + state: Current workflow state, including loop control flags. + + Effects: + - Sets 'looping' in the state to False to stop the loop. + - Logs the action for debugging purposes. + + Returns: + Optional dictionary indicating the loop has been stopped. 
+ """ + state["looping"] = False + logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}") + diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index b1c64227..2ba23d4c 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig __all__ = [ # 基础类 "BaseNodeConfig", @@ -41,5 +42,7 @@ __all__ = [ "JinjaRenderNodeConfig", "VariableAggregatorNodeConfig", "ParameterExtractorNodeConfig", + "LoopNodeConfig", + "IterationNodeConfig", "QuestionClassifierNodeConfig" ] diff --git a/api/app/core/workflow/nodes/cycle_graph/__init__.py b/api/app/core/workflow/nodes/cycle_graph/__init__.py new file mode 100644 index 00000000..dc2d72e0 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode + +__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig'] diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py new file mode 100644 index 00000000..b1b613a4 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -0,0 +1,96 @@ +from pydantic import Field, BaseModel + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator + + +class CycleVariable(BaseNodeConfig): + name: str = Field( + ..., + description="Name of the loop variable" + ) 
+ type: VariableType = Field( + ..., + description="Data type of the loop variable" + ) + value: str = Field( + ..., + description="Initial or current value of the loop variable" + ) + + +class ConditionDetail(BaseModel): + comparison_operator: ComparisonOperator = Field( + ..., + description="Operator used to compare the left and right operands" + ) + + left: str = Field( + ..., + description="Left-hand operand of the comparison expression" + ) + + right: str = Field( + ..., + description="Right-hand operand of the comparison expression" + ) + + +class ConditionsConfig(BaseModel): + """Configuration for loop condition evaluation""" + + logical_operator: LogicOperator = Field( + default=LogicOperator.AND.value, + description="Logical operator used to combine multiple condition expressions" + ) + + expressions: list[ConditionDetail] = Field( + ..., + description="Collection of condition expressions to be evaluated" + ) + + +class LoopNodeConfig(BaseNodeConfig): + condition: ConditionsConfig = Field( + default_factory=list, + description="Conditional configuration that controls loop execution" + ) + + cycle_vars: list[CycleVariable] = Field( + default_factory=list, + description="List of variables used and updated during the loop" + ) + + max_loop: int = Field( + default=10, + description="Maximum number of loop iterations" + ) + + +class IterationNodeConfig(BaseNodeConfig): + input: str = Field( + ..., + description="Input of the loop iteration" + ) + + parallel: bool = Field( + default=False, + description="Whether to execute loop iterations in parallel" + ) + + parallel_count: int = Field( + default=4, + description="Number of iterations to run in parallel" + ) + + flatten: bool = Field( + default=False, + description="Whether to flatten the output list from iterations" + ) + + output: str = Field( + ..., + description="Output of the loop iteration" + ) + + diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py 
b/api/app/core/workflow/nodes/cycle_graph/iteration.py new file mode 100644 index 00000000..4ae8e118 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -0,0 +1,154 @@ +import asyncio +import copy +import logging +import re +from typing import Any + +from langgraph.graph.state import CompiledStateGraph + +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.cycle_graph import IterationNodeConfig +from app.core.workflow.variable_pool import VariablePool + +logger = logging.getLogger(__name__) + + +class IterationRuntime: + """ + Runtime executor for loop/iteration nodes in a workflow. + + This class handles executing iterations over a list variable, supporting + optional parallel execution, flattening of output, and loop control via + the workflow state. + """ + def __init__( + self, + graph: CompiledStateGraph, + node_id: str, + config: dict[str, Any], + state: WorkflowState, + ): + """ + Initialize the iteration runtime. + + Args: + graph: Compiled workflow graph capable of async invocation. + node_id: Unique identifier of the loop node. + config: Dictionary containing iteration node configuration. + state: Current workflow state at the point of iteration. + """ + self.graph = graph + self.state = state + self.node_id = node_id + self.typed_config = IterationNodeConfig(**config) + self.looping = True + + self.output_value = None + self.result: list = [] + + def _init_iteration_state(self, item, idx): + """ + Initialize a per-iteration copy of the workflow state. + + Args: + item: Current element from the input array for this iteration. + idx: Index of the element in the input array. + + Returns: + A deep copy of the workflow state with iteration-specific variables set. 
+ """ + loopstate = WorkflowState( + **copy.deepcopy(self.state) + ) + loopstate["runtime_vars"][self.node_id] = { + "item": item, + "index": idx, + } + loopstate["node_outputs"][self.node_id] = { + "item": item, + "index": idx, + } + loopstate["looping"] = True + return loopstate + + async def run_task(self, item, idx): + """ + Execute a single iteration asynchronously. + + Args: + item: The input element for this iteration. + idx: The index of this iteration. + """ + result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + output = VariablePool(result).get(self.output_value) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + if not result["looping"]: + self.looping = False + + def _create_iteration_tasks(self, array_obj, idx): + """ + Create async tasks for a batch of iterations based on parallel count. + + Args: + array_obj: The input array to iterate over. + idx: Starting index for this batch of iterations. + + Returns: + List of coroutine tasks ready to be executed in parallel. + """ + tasks = [] + for i in range(self.typed_config.parallel_count): + if idx + i >= len(array_obj): + break + item = array_obj[idx + i] + tasks.append(self.run_task(item, idx + i)) + return tasks + + async def run(self): + """ + Execute the loop over the input array according to configuration. + + Returns: + A list of outputs from all iterations, optionally flattened. + + Raises: + RuntimeError: If the input variable is not a list. 
+ """ + pattern = r"\{\{\s*(.*?)\s*\}\}" + input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip() + self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip() + + array_obj = VariablePool(self.state).get(input_expression) + if not isinstance(array_obj, list): + raise RuntimeError("Cannot iterate over a non-list variable") + + idx = 0 + if self.typed_config.parallel: + # Execute iterations in parallel batches + while idx < len(array_obj) and self.looping: + tasks = self._create_iteration_tasks(array_obj, idx) + logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") + idx += self.typed_config.parallel_count + await asyncio.gather(*tasks) + logger.info(f"Iteration node {self.node_id}: execution completed") + return self.result + else: + # Execute iterations sequentially + while idx < len(array_obj) and self.looping: + logger.info(f"Iteration node {self.node_id}: running") + item = array_obj[idx] + result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + output = VariablePool(result).get(self.output_value) + if isinstance(output, list) and self.typed_config.flatten: + self.result.extend(output) + else: + self.result.append(output) + if not result["looping"]: + self.looping = False + idx += 1 + + logger.info(f"Iteration node {self.node_id}: execution completed") + return self.result diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py new file mode 100644 index 00000000..af75d372 --- /dev/null +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -0,0 +1,130 @@ +import logging +from typing import Any + +from langgraph.graph.state import CompiledStateGraph + +from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig +from app.core.workflow.nodes.operators import 
class LoopRuntime:
    """
    Runtime executor for loop nodes in a workflow.

    Handles iterative execution of a loop node according to defined loop variables
    and conditional expressions. Supports maximum loop count and loop control
    through the workflow state.
    """

    def __init__(
            self,
            graph: CompiledStateGraph,
            node_id: str,
            config: dict[str, Any],
            state: WorkflowState,
    ):
        """
        Initialize the loop runtime.

        Args:
            graph: Compiled workflow graph capable of async invocation.
            node_id: Unique identifier of the loop node.
            config: Dictionary containing loop node configuration.
            state: Current workflow state at the point of loop execution.
        """
        self.graph = graph
        self.state = state
        self.node_id = node_id
        self.typed_config = LoopNodeConfig(**config)

    def _evaluate_cycle_vars(self, pool: VariablePool) -> dict[str, Any]:
        """Evaluate the initial value of every configured loop variable."""
        return {
            variable.name: evaluate_expression(
                expression=variable.value,
                variables=pool.get_all_conversation_vars(),
                node_outputs=pool.get_all_node_outputs(),
                system_vars=pool.get_all_system_vars(),
            )
            for variable in self.typed_config.cycle_vars
        }

    def _init_loop_state(self):
        """
        Initialize workflow state for loop execution.

        - Evaluates initial values of loop variables.
        - Stores loop variables in runtime_vars and node_outputs.
        - Marks the loop as active by setting 'looping' to True.

        Returns:
            A shallow copy of the workflow state prepared for the loop
            execution; nested containers are shared with ``self.state``.
        """
        pool = VariablePool(self.state)
        # Loop variables: evaluated once per target so the two entries do
        # not alias the same dict.
        self.state["runtime_vars"][self.node_id] = self._evaluate_cycle_vars(pool)
        self.state["node_outputs"][self.node_id] = self._evaluate_cycle_vars(pool)
        loopstate = WorkflowState(
            **self.state
        )
        loopstate["looping"] = True
        return loopstate

    def _get_loop_expression(self):
        """
        Build the Python boolean expression for evaluating the loop condition.

        - Converts each condition in the loop configuration into a Python expression string.
        - Combines multiple conditions with the configured logical operator (AND/OR).

        Returns:
            A string representing the combined loop condition expression.
        """
        branch_conditions = [
            ConditionExpressionBuilder(
                left=condition.left,
                operator=condition.comparison_operator,
                right=condition.right
            ).build()
            for condition in self.typed_config.condition.expressions
        ]
        if len(branch_conditions) > 1:
            combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
        else:
            combined_condition = branch_conditions[0]

        return combined_condition

    async def run(self):
        """
        Execute the loop node until the condition is no longer met, the loop is
        manually stopped, or the maximum loop count is reached.

        Returns:
            The final runtime variables of this loop node after completion.
        """
        loopstate = self._init_loop_state()
        expression = self._get_loop_expression()
        # NOTE(review): the pool wraps the input state, so the condition only
        # sees variable updates that sub-graph nodes apply in place to the
        # shared nested dicts - confirm against the assigner nodes.
        loop_variable_pool = VariablePool(loopstate)
        loop_time = self.typed_config.max_loop
        while evaluate_condition(
                expression=expression,
                variables=loop_variable_pool.get_all_conversation_vars(),
                node_outputs=loop_variable_pool.get_all_node_outputs(),
                system_vars=loop_variable_pool.get_all_system_vars(),
        ) and loopstate["looping"] and loop_time > 0:
            logger.info(f"loop node {self.node_id}: running")
            result = await self.graph.ainvoke(loopstate)
            # Propagate a stop request from the sub-graph: the flag lives in
            # the returned state, the same place IterationRuntime reads it.
            if not result.get("looping", True):
                loopstate["looping"] = False
            loop_time -= 1

        logger.info(f"loop node {self.node_id}: execution completed")
        return loopstate["runtime_vars"][self.node_id]
class CycleGraphNode(BaseNode):
    """
    Node representing a cycle (loop) subgraph within the workflow.

    This node manages internal loop/iteration nodes, builds a subgraph
    for execution, handles conditional routing, and executes loop
    or iteration logic based on node type.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None

        self.cycle_nodes = list()  # Nodes belonging to this cycle
        self.cycle_edges = list()  # Edges connecting nodes within the cycle
        self.start_node_id = None  # ID of the start node within the cycle
        self.end_node_ids = []  # IDs of end nodes within the cycle

        self.graph: StateGraph | CompiledStateGraph | None = None
        self.build_graph()
        self.iteration_flag = True

    def pure_cycle_graph(self) -> tuple[list, list]:
        """
        Extract cycle nodes and internal edges from the workflow configuration,
        removing them from the global workflow.

        Raises:
            ValueError: If cycle nodes are connected to external nodes improperly.

        Returns:
            Tuple containing:
                - cycle_nodes: List of removed nodes
                - cycle_edges: List of removed edges
        """
        nodes = self.workflow_config.get("nodes", [])
        edges = self.workflow_config.get("edges", [])

        # Select all nodes that belong to the current cycle
        cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
        cycle_node_ids = {node.get("id") for node in cycle_nodes}

        cycle_edges = []
        remain_edges = []

        for edge in edges:
            source_in = edge.get("source") in cycle_node_ids
            target_in = edge.get("target") in cycle_node_ids

            # Raise error if cycle nodes are connected with external nodes
            if source_in ^ target_in:
                raise ValueError(
                    f"Cycle node is connected to external node, "
                    f"source: {edge.get('source')}, target: {edge.get('target')}"
                )

            if source_in and target_in:
                cycle_edges.append(edge)
            else:
                remain_edges.append(edge)

        # Update workflow_config by removing cycle nodes and internal edges
        self.workflow_config["nodes"] = [
            node for node in nodes if node.get("cycle") != self.node_id
        ]
        self.workflow_config["edges"] = remain_edges

        return cycle_nodes, cycle_edges

    def create_node(self):
        """
        Instantiate node objects for each node in the cycle subgraph and add them to the graph.

        Special handling is applied for conditional nodes to generate
        edge conditions based on node outputs.
        """
        # Imported here rather than at module level; presumably this avoids a
        # circular import with the nodes package - verify before moving it.
        from app.core.workflow.nodes import NodeFactory
        for node in self.cycle_nodes:
            node_type = node.get("type")
            node_id = node.get("id")

            if node_type == NodeType.CYCLE_START:
                # The start marker is not a real graph node; edges leaving it
                # are attached to START in create_edge().
                self.start_node_id = node_id
                continue
            elif node_type == NodeType.END:
                self.end_node_ids.append(node_id)

            node_instance = NodeFactory.create_node(node, self.workflow_config)

            if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
                expressions = node_instance.build_conditional_edge_expressions()

                # Number of branches, usually matches the number of conditional expressions
                branch_number = len(expressions)

                # Find all edges whose source is the current node
                related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]

                # Give each branch edge a condition that matches the node's
                # output against the edge label, e.g. node.123.output == 'CASE1'.
                # Slicing (instead of indexing by branch number) tolerates a
                # node with fewer outgoing edges than branches.
                for edge in related_edge[:branch_number]:
                    edge['condition'] = f"node.{node_id}.output == '{edge['label']}'"

            def make_func(inst):
                # Bind the instance per node so each closure runs its own node.
                async def node_func(state: WorkflowState):
                    return await inst.run(state)

                return node_func

            self.graph.add_node(node_id, make_func(node_instance))

    @staticmethod
    def _make_router(routes: list[tuple[str, str]]):
        """Build one routing function for all conditional edges of a source node.

        Args:
            routes: ``(condition, target)`` pairs for the edges leaving the node.

        Returns:
            A function that maps a workflow state to the list of targets whose
            condition holds, or to END when none does.
        """

        def router(state: WorkflowState):
            system_vars = {
                "execution_id": state.get("execution_id"),
                "workspace_id": state.get("workspace_id"),
                "user_id": state.get("user_id")
            }
            matched = [
                target
                for condition, target in routes
                if evaluate_condition(
                    condition,
                    state.get("variables", {}),
                    state.get("node_outputs", {}),
                    system_vars
                )
            ]
            # No condition satisfied: end this branch.
            return matched or END

        return router

    def create_edge(self):
        """
        Connect nodes within the cycle subgraph by adding edges to the internal graph.

        Conditional edges are routed based on evaluated expressions.
        Start and end nodes are connected to global START and END nodes.
        """
        # Conditional edges are collected per source and registered once:
        # LangGraph keys conditional branches by the router's name, so adding
        # several functions named "router" to one source would collide.
        conditional_routes: dict[str, list[tuple[str, str]]] = {}

        for edge in self.cycle_edges:
            source = edge.get("source")
            target = edge.get("target")
            condition = edge.get("condition")

            # Edges leaving the cycle-start marker are attached to START,
            # because the marker itself is never added as a graph node.
            if source == self.start_node_id:
                self.graph.add_edge(START, target)
                logger.debug(f"添加边: {source} -> {target}")
                continue

            if condition:
                # Conditional edge
                conditional_routes.setdefault(source, []).append((condition, target))
                logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
            else:
                # Plain edge
                self.graph.add_edge(source, target)
                logger.debug(f"添加边: {source} -> {target}")

        for source, routes in conditional_routes.items():
            self.graph.add_conditional_edges(source, self._make_router(routes))

        # Connect every end node of the cycle to END
        for end_node_id in self.end_node_ids:
            self.graph.add_edge(end_node_id, END)
            logger.debug(f"添加边: {end_node_id} -> END")

    def build_graph(self):
        """
        Build the internal subgraph for the cycle node.

        Steps:
            1. Extract cycle nodes and edges.
            2. Create node instances and add them to the graph.
            3. Connect edges and conditional routes.
            4. Compile the graph for execution.
        """
        self.graph = StateGraph(WorkflowState)
        self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
        self.create_node()
        self.create_edge()
        self.graph = self.graph.compile()

    async def execute(self, state: WorkflowState) -> Any:
        """
        Execute the cycle node at runtime.

        Depending on the node type, runs either a loop (LoopRuntime)
        or an iteration (IterationRuntime) over the internal subgraph.

        Args:
            state: Current workflow state.

        Returns:
            Runtime result of the cycle, typically the final loop/iteration variables.

        Raises:
            RuntimeError: If node type is unrecognized.
        """
        if self.node_type == NodeType.LOOP:
            return await LoopRuntime(
                graph=self.graph,
                node_id=self.node_id,
                config=self.config,
                state=state,
            ).run()
        if self.node_type == NodeType.ITERATION:
            return await IterationRuntime(
                graph=self.graph,
                node_id=self.node_id,
                config=self.config,
                state=state,
            ).run()
        raise RuntimeError("Unknown cycle node type")
**self._build_content(state) ) resp.raise_for_status() + logger.info(f"Node {self.node_id}: HTTP request succeeded") return HttpRequestNodeOutput( body=resp.text, status_code=resp.status_code, @@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode): else: match self.typed_config.error_handle.method: case HttpErrorHandle.NONE: + logger.warning( + f"Node {self.node_id}: HTTP request failed, returning error response" + ) return HttpRequestNodeOutput( body="", status_code=resp.status_code, headers=resp.headers, ).model_dump() case HttpErrorHandle.DEFAULT: + logger.warning( + f"Node {self.node_id}: HTTP request failed, returning default result" + ) return self.typed_config.error_handle.default.model_dump() case HttpErrorHandle.BRANCH: + logger.warning( + f"Node {self.node_id}: HTTP request failed, switching to error handling branch" + ) return "ERROR" diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 4e424b54..9eddb473 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel): description="Logical operator used to combine multiple condition expressions" ) - conditions: list[ConditionDetail] = Field( + expressions: list[ConditionDetail] = Field( ..., description="List of condition expressions within this branch" ) @@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig): # CASE1 / IF Branch { "logical_operator": "and", - "conditions": [ + "expressions": [ [ { "left": "node.userinput.message", @@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig): # CASE1 / ELIF Branch { "logical_operator": "or", - "conditions": [ + "expressions": [ [ { "left": "node.userinput.test", diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 579c2840..1450a28f 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ 
-2,93 +2,13 @@ import logging from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState -from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import ConditionDetail +from app.core.workflow.nodes.operators import ConditionExpressionBuilder logger = logging.getLogger(__name__) -class ConditionExpressionBuilder: - """ - Build a Python boolean expression string based on a comparison operator. - - This class does not evaluate the expression. - It only generates a valid Python expression string - that can be evaluated later in a workflow context. - """ - - def __init__(self, left: str, operator: ComparisonOperator, right: str): - self.left = left - self.operator = operator - self.right = right - - def _empty(self): - return f"{self.left} == ''" - - def _not_empty(self): - return f"{self.left} != ''" - - def _contains(self): - return f"{self.right} in {self.left}" - - def _not_contains(self): - return f"{self.right} not in {self.left}" - - def _startwith(self): - return f'{self.left}.startswith({self.right})' - - def _endwith(self): - return f'{self.left}.endswith({self.right})' - - def _eq(self): - return f"{self.left} == {self.right}" - - def _ne(self): - return f"{self.left} != {self.right}" - - def _lt(self): - return f"{self.left} < {self.right}" - - def _le(self): - return f"{self.left} <= {self.right}" - - def _gt(self): - return f"{self.left} > {self.right}" - - def _ge(self): - return f"{self.left} >= {self.right}" - - def build(self): - match self.operator: - case ComparisonOperator.EMPTY: - return self._empty() - case ComparisonOperator.NOT_EMPTY: - return self._not_empty() - case ComparisonOperator.CONTAINS: - return self._contains() - case ComparisonOperator.NOT_CONTAINS: - return self._not_contains() - case ComparisonOperator.START_WITH: - return self._startwith() - case ComparisonOperator.END_WITH: - return 
self._endwith() - case ComparisonOperator.EQ: - return self._eq() - case ComparisonOperator.NE: - return self._ne() - case ComparisonOperator.LT: - return self._lt() - case ComparisonOperator.LE: - return self._le() - case ComparisonOperator.GT: - return self._gt() - case ComparisonOperator.GE: - return self._ge() - case _: - raise ValueError(f"Invalid condition: {self.operator}") - - class IfElseNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) @@ -143,7 +63,7 @@ class IfElseNode(BaseNode): branch_conditions = [ self._build_condition_expression(condition) - for condition in case_branch.conditions + for condition in case_branch.expressions ] if len(branch_conditions) > 1: combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) @@ -174,5 +94,6 @@ class IfElseNode(BaseNode): for i in range(len(expressions)): logger.info(expressions[i]) if self._evaluate_condition(expressions[i], state): + logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") return f'CASE{i + 1}' return f'CASE{len(expressions)}' diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 60beefb6..6130c30a 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -1,3 +1,4 @@ +import logging from typing import Any from app.core.workflow.nodes import WorkflowState @@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.template_renderer import TemplateRenderer +logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): @@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode): res = render.env.from_string(self.typed_config.template).render(**context) 
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
    """Configuration of a knowledge-retrieval workflow node.

    Holds the query, the per-knowledge-base retrieval settings, and the
    reranker used for the final reranking pass.
    """

    query: str = Field(
        ...,
        description="Search query string"
    )

    knowledge_bases: list[KnowledgeBaseConfig] = Field(
        ...,
        description="Knowledge base config"
    )

    # ID of the model configuration to load as the reranker.
    reranker_id: UUID = Field(
        ...,
        description="Reranker model ID"
    )

    # Number of chunks kept after the final rerank over all knowledge bases.
    reranker_top_k: int = Field(
        default=4,
        description="Reranker top k"
    )

    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "query": "{{sys.message}}",
                    "knowledge_bases": [{
                        "kb_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
                        "similarity_threshold": 0.2,
                        "vector_similarity_weight": 0.3,
                        "top_k": 4,
                        "retrieve_type": "hybrid"
                    }],
                    "reranker_top_k": 1,
                    "reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
                }
            ]
        }
def get_reranker_model(self) -> RedBearRerank:
    """
    Build the RedBear reranker configured for this node.

    Looks up the model referenced by ``typed_config.reranker_id`` and wraps
    the first API key entry of that model in a ``RedBearRerank``.

    Raises:
        BusinessException: If the model configuration does not exist or has
            no API key entries.
        RuntimeError: If the configured model is not of type RERANK.
    """
    with get_db_read() as db:
        config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id)

        if not config:
            raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)

        if not config.api_keys:
            raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)

        # Copy every value we need while the session is still open.
        first_key = config.api_keys[0]
        model_config_kwargs = {
            "model_name": first_key.model_name,
            "provider": first_key.provider,
            "api_key": first_key.api_key,
            "base_url": first_key.api_base,
        }
        model_type = config.type

    if model_type != ModelType.RERANK:
        raise RuntimeError("Model is not a reranker")

    return RedBearRerank(RedBearModelConfig(**model_config_kwargs))
denied.") - - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - - match self.typed_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - case RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - case _: - rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - # Deduplicate hybrid retrieval results - unique_rs = self._deduplicate_docs(rs1, rs2) - rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) - return [chunk.model_dump() for chunk in rs] + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + indices = f"Vector_index_{kb_config.kb_id}_Node".lower() + match kb_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold)) + case RetrieveType.SEMANTIC: + rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight)) + case RetrieveType.HYBRID: + rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold) + # Deduplicate hybrid retrieval 
results + unique_rs = self._deduplicate_docs(rs1, rs2) + vector_service.reranker = self.get_reranker_model() + rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + case _: + raise RuntimeError("Unknown retrieval type") + vector_service.reranker = self.get_reranker_model() + final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + logger.info( + f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}" + ) + return [chunk.model_dump() for chunk in final_rs] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 90c48ac0..ed26533d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -10,6 +10,7 @@ from typing import Any, Union from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.http_request import HttpRequestNode @@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode +from app.core.workflow.nodes.breaker import BreakNode logger = logging.getLogger(__name__) @@ -39,6 +41,9 @@ WorkflowNode = Union[ JinjaRenderNode, VariableAggregatorNode, ParameterExtractorNode, + CycleGraphNode, + BreakNode, + ParameterExtractorNode, QuestionClassifierNode ] @@ -64,6 +69,9 @@ class NodeFactory: NodeType.VAR_AGGREGATOR: VariableAggregatorNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.QUESTION_CLASSIFIER: 
class AssignmentOperatorResolver:
    """Pick the assignment operator class that matches a value's runtime type."""

    @classmethod
    def resolve_by_value(cls, value):
        """
        Return the operator class for ``value``.

        Raises:
            TypeError: If ``value`` is not a str, bool, int, float, list or dict.
        """
        # Checked top to bottom; bool comes before the numeric entry because
        # bool is a subclass of int.
        candidates = (
            (str, StringOperator),
            (bool, BooleanOperator),
            ((int, float), NumberOperator),
            (list, ArrayOperator),
            (dict, ObjectOperator),
        )
        for expected_type, operator_cls in candidates:
            if isinstance(value, expected_type):
                return operator_cls
        raise TypeError(f"Unsupported variable type: {type(value)}")
class ConditionExpressionBuilder:
    """
    Render a comparison as a Python boolean expression string.

    Nothing is evaluated here: ``build()`` only interpolates the left and
    right operands into the text of an expression, which callers evaluate
    later in a workflow context.
    """

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    def _empty(self):
        return f"{self.left} == ''"

    def _not_empty(self):
        return f"{self.left} != ''"

    def _contains(self):
        return f"{self.right} in {self.left}"

    def _not_contains(self):
        return f"{self.right} not in {self.left}"

    def _startswith(self):
        return f"{self.left}.startswith({self.right})"

    def _endswith(self):
        return f"{self.left}.endswith({self.right})"

    def _eq(self):
        return f"{self.left} == {self.right}"

    def _ne(self):
        return f"{self.left} != {self.right}"

    def _lt(self):
        return f"{self.left} < {self.right}"

    def _le(self):
        return f"{self.left} <= {self.right}"

    def _gt(self):
        return f"{self.left} > {self.right}"

    def _ge(self):
        return f"{self.left} >= {self.right}"

    def build(self):
        """
        Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator is not a known comparison operator.
        """
        renderers = (
            (ComparisonOperator.EMPTY, self._empty),
            (ComparisonOperator.NOT_EMPTY, self._not_empty),
            (ComparisonOperator.CONTAINS, self._contains),
            (ComparisonOperator.NOT_CONTAINS, self._not_contains),
            (ComparisonOperator.START_WITH, self._startswith),
            (ComparisonOperator.END_WITH, self._endswith),
            (ComparisonOperator.EQ, self._eq),
            (ComparisonOperator.NE, self._ne),
            (ComparisonOperator.LT, self._lt),
            (ComparisonOperator.LE, self._le),
            (ComparisonOperator.GT, self._gt),
            (ComparisonOperator.GE, self._ge),
        )
        # Compared with == in declaration order, first match wins.
        for candidate, render in renderers:
            if self.operator == candidate:
                return render()
        raise ValueError(f"Invalid condition: {self.operator}")
class ParamsConfig(BaseModel): description="Description of the parameter" ) + required: bool = Field( + ..., + description="Whether the parameter is required" + ) + class ParameterExtractorNodeConfig(BaseNodeConfig): model_id: uuid.UUID = Field( @@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig): ..., description="List of parameters" ) + + prompt: str = Field( + ..., + description="User-provided supplemental prompt" + ) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 0eb3bfd4..af2a4478 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -1,4 +1,5 @@ import os +import logging import json_repair from typing import Any @@ -15,6 +16,8 @@ from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService +logger = logging.getLogger(__name__) + class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): @@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode): """ field_type = {} for param in self.typed_config.params: - field_type[param.name] = param.type + field_type[param.name] = f'{param.type}, required:{str(param.required)}' return field_type async def execute(self, state: WorkflowState) -> Any: @@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode): messages = [ ("system", system_prompt), + ("user", self._render_template(self.typed_config.prompt, state)), ("user", rendered_user_prompt), ] model_resp = await llm.ainvoke(messages) - result = json_repair.repair_json(model_resp.content) + result = json_repair.repair_json(model_resp.content, return_objects=True) + logger.info(f"node: {self.node_id} get params:{result}") - return { - "output": result, - } + return result diff --git a/api/app/core/workflow/nodes/variable_aggregator/config.py 
b/api/app/core/workflow/nodes/variable_aggregator/config.py index 84f82487..ac1419a4 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/config.py +++ b/api/app/core/workflow/nodes/variable_aggregator/config.py @@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig): description="输出变量是否需要分组", ) - group_names: list[str] = Field( - default_factory=lambda: ["output"], - description="各个分组的名称" - ) - - group_variables: list[str] | list[list[str]] = Field( + group_variables: list[str] | dict[str, list[str]] = Field( ..., description="需要被聚合的变量" ) - @field_validator("group_names", mode="before") - @classmethod - def group_names_validator(cls, v, info): - group_status = info.data.get("group") - if not group_status or not v: - return ["output"] - return v - @field_validator("group_variables") @classmethod def group_variables_validator(cls, v, info): group_status = info.data.get("group") - group_names = info.data.get("group_names") - if not isinstance(v, list): - raise ValueError("group_variables must be a list") if not group_status: for variable in v: if not isinstance(variable, str): raise ValueError("When group=False, group_variables must be a list of strings") else: - if len(group_names) != len(v): - raise ValueError("group_names and group_variables length mismatch") - for group in v: - if not isinstance(group, list): + if not isinstance(v, dict): + raise ValueError("When group=True, group_variables must be a dict") + for group_name, group_values in v.items(): + if not isinstance(group_name, str): raise ValueError("When group=True, each element of group_variables must be a list") - for variable in group: + for variable in group_values: if not isinstance(variable, str): raise ValueError("Each element inside group_variables lists must be a string") return v diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index f53f9269..e6cbf75b 100644 --- 
a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode): continue if value is not None: + logger.info(f"Node: {self.node_id} variable aggregation result: {value}") return value logger.info("No variable found in non-group mode; returning empty string.") @@ -59,7 +60,7 @@ class VariableAggregatorNode(BaseNode): # Group mode # -------------------------- result = {} - for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables): + for group_name, variables in self.typed_config.group_variables.items(): for variable in variables: var_express = self._get_express(variable) try: @@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode): else: result[group_name] = "" logger.info(f"No variable found for group '{group_name}'; set empty string.") - + logger.info(f"Node: {self.node_id} variable aggregation result: {result}") return result diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 58bc20b9..00358d91 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -7,14 +7,87 @@ import logging from typing import Any, Union +from app.core.workflow.nodes.enums import NodeType + logger = logging.getLogger(__name__) class WorkflowValidator: """工作流配置验证器""" - - @staticmethod - def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: + + @classmethod + def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]: + """ + Extract cycle nodes and internal edges from the workflow configuration, + removing them from the global workflow. + + Raises: + ValueError: If cycle nodes are connected to external nodes improperly. 
+ + Returns: + Tuple containing: + - cycle_nodes: List of removed nodes + - cycle_edges: List of removed edges + """ + nodes = workflow_config.get("nodes", []) + edges = workflow_config.get("edges", []) + + # Select all nodes that belong to the current cycle + cycle_nodes = [node for node in nodes if node.get("cycle") == node_id] + cycle_node_ids = {node.get("id") for node in cycle_nodes} + + cycle_edges = [] + remain_edges = [] + + for edge in edges: + source_in = edge.get("source") in cycle_node_ids + target_in = edge.get("target") in cycle_node_ids + + # Raise error if cycle nodes are connected with external nodes + if source_in ^ target_in: + raise ValueError( + f"Cycle node is connected to external node, " + f"source: {edge.get('source')}, target: {edge.get('target')}" + ) + + if source_in and target_in: + cycle_edges.append(edge) + else: + remain_edges.append(edge) + + # Update workflow_config by removing cycle nodes and internal edges + workflow_config["nodes"] = [ + node for node in nodes if node.get("cycle") != node_id + ] + workflow_config["edges"] = remain_edges + + return cycle_nodes, cycle_edges + + @classmethod + def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list: + if not isinstance(workflow_config, dict): + workflow_config = { + "nodes": workflow_config.nodes, + "edges": workflow_config.edges, + "variables": workflow_config.variables, + } + cycle_nodes = [ + node.get("id") + for node in workflow_config.get("nodes", []) + if node.get("type") in [NodeType.LOOP, NodeType.ITERATION] + ] + graphs = [] + for cycle_node in cycle_nodes: + nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node) + graphs.append({ + "nodes": nodes, + "edges": edges, + }) + graphs.append(workflow_config) + return graphs + + @classmethod + def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: """验证工作流配置 Args: @@ -38,84 +111,79 @@ class WorkflowValidator: True """ errors = [] - - # 支持字典和 Pydantic 模型 - if 
isinstance(workflow_config, dict): - nodes = workflow_config.get("nodes", []) - edges = workflow_config.get("edges", []) - variables = workflow_config.get("variables", []) - else: - # Pydantic 模型 - nodes = getattr(workflow_config, "nodes", []) - edges = getattr(workflow_config, "edges", []) - variables = getattr(workflow_config, "variables", []) - - # 1. 验证 start 节点(有且只有一个) - start_nodes = [n for n in nodes if n.get("type") == "start"] - if len(start_nodes) == 0: - errors.append("工作流必须有一个 start 节点") - elif len(start_nodes) > 1: - errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") - - # 2. 验证 end 节点(至少一个) - end_nodes = [n for n in nodes if n.get("type") == "end"] - if len(end_nodes) == 0: - errors.append("工作流必须至少有一个 end 节点") - - # 3. 验证节点 ID 唯一性 - node_ids = [n.get("id") for n in nodes] - if len(node_ids) != len(set(node_ids)): - duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] - errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") - - # 4. 验证节点必须有 id 和 type - for i, node in enumerate(nodes): - if not node.get("id"): - errors.append(f"节点 #{i} 缺少 id 字段") - if not node.get("type"): - errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段") - - # 5. 验证边的有效性 - node_id_set = set(node_ids) - for i, edge in enumerate(edges): - source = edge.get("source") - target = edge.get("target") - - if not source: - errors.append(f"边 #{i} 缺少 source 字段") - elif source not in node_id_set: - errors.append(f"边 #{i} 的 source 节点不存在: {source}") - - if not target: - errors.append(f"边 #{i} 缺少 target 字段") - elif target not in node_id_set: - errors.append(f"边 #{i} 的 target 节点不存在: {target}") - - # 6. 验证所有节点可达(从 start 节点出发) - if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( - start_nodes[0]["id"], - edges - ) - unreachable = node_id_set - reachable - if unreachable: - errors.append(f"以下节点无法从 start 节点到达: {unreachable}") - - # 7. 
检测循环依赖(非 loop 节点) - if not errors: # 只有在前面验证通过时才检查循环 - has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) - if has_cycle: - errors.append( - f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + + graphs = cls.get_subgraph(workflow_config) + logger.info(graphs) + for graph in graphs: + nodes = graph.get("nodes", []) + edges = graph.get("edges", []) + variables = graph.get("variables", []) + # 1. 验证 start 节点(有且只有一个) + start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]] + if len(start_nodes) == 0: + errors.append("工作流必须有一个 start 节点") + elif len(start_nodes) > 1: + errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") + + # 2. 验证 end 节点(至少一个) + end_nodes = [n for n in nodes if n.get("type") == NodeType.END] + if len(end_nodes) == 0: + errors.append("工作流必须至少有一个 end 节点") + + # 3. 验证节点 ID 唯一性 + node_ids = [n.get("id") for n in nodes] + if len(node_ids) != len(set(node_ids)): + duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] + errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") + + # 4. 验证节点必须有 id 和 type + for i, node in enumerate(nodes): + if not node.get("id"): + errors.append(f"节点 #{i} 缺少 id 字段") + if not node.get("type"): + errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段") + + # 5. 验证边的有效性 + node_id_set = set(node_ids) + for i, edge in enumerate(edges): + source = edge.get("source") + target = edge.get("target") + + if not source: + errors.append(f"边 #{i} 缺少 source 字段") + elif source not in node_id_set: + errors.append(f"边 #{i} 的 source 节点不存在: {source}") + + if not target: + errors.append(f"边 #{i} 缺少 target 字段") + elif target not in node_id_set: + errors.append(f"边 #{i} 的 target 节点不存在: {target}") + + # 6. 验证所有节点可达(从 start 节点出发) + if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 + reachable = WorkflowValidator._get_reachable_nodes( + start_nodes[0]["id"], + edges ) - - # 8. 
验证变量名 - from app.core.workflow.expression_evaluator import ExpressionEvaluator - var_errors = ExpressionEvaluator.validate_variable_names(variables) - errors.extend(var_errors) - + unreachable = node_id_set - reachable + if unreachable: + errors.append(f"以下节点无法从 start 节点到达: {unreachable}") + + # 7. 检测循环依赖(非 loop 节点) + if not errors: # 只有在前面验证通过时才检查循环 + has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) + if has_cycle: + errors.append( + f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + ) + + # 8. 验证变量名 + from app.core.workflow.expression_evaluator import ExpressionEvaluator + var_errors = ExpressionEvaluator.validate_variable_names(variables) + errors.extend(var_errors) + return len(errors) == 0, errors - + @staticmethod def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: """获取从 start 节点可达的所有节点 @@ -129,7 +197,7 @@ class WorkflowValidator: """ reachable = {start_id} queue = [start_id] - + while queue: current = queue.pop(0) for edge in edges: @@ -138,9 +206,9 @@ class WorkflowValidator: if target and target not in reachable: reachable.add(target) queue.append(target) - + return reachable - + @staticmethod def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]: """检测是否存在循环依赖(DFS) @@ -154,39 +222,39 @@ class WorkflowValidator: """ # 排除 loop 类型的节点 loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} - + # 构建邻接表(排除 loop 节点的边和错误边) graph: dict[str, list[str]] = {} for edge in edges: source = edge.get("source") target = edge.get("target") edge_type = edge.get("type") - + # 跳过错误边 if edge_type == "error": continue - + # 如果涉及 loop 节点,跳过 if source in loop_nodes or target in loop_nodes: continue - + if source and target: if source not in graph: graph[source] = [] graph[source].append(target) - + # DFS 检测环 visited = set() rec_stack = set() path = [] cycle_path = [] - + def dfs(node: str) -> bool: """DFS 检测环,返回是否找到环""" visited.add(node) rec_stack.add(node) path.append(node) - + for neighbor in 
graph.get(node, []): if neighbor not in visited: if dfs(neighbor): @@ -196,19 +264,19 @@ class WorkflowValidator: cycle_start = path.index(neighbor) cycle_path.extend([*path[cycle_start:], neighbor]) return True - + rec_stack.remove(node) path.pop() return False - + # 检查所有节点 for node_id in graph: if node_id not in visited: if dfs(node_id): return True, cycle_path - + return False, [] - + @staticmethod def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]: """验证工作流配置是否可以发布(更严格的验证) @@ -221,30 +289,30 @@ class WorkflowValidator: """ # 先执行基础验证 is_valid, errors = WorkflowValidator.validate(workflow_config) - + if not is_valid: return False, errors - + # 额外的发布验证 nodes = workflow_config.get("nodes", []) - + # 1. 验证所有节点都有名称 for node in nodes: - if node.get("type") not in ["start", "end"] and not node.get("name"): + if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"): errors.append( f"节点 {node.get('id')} 缺少名称(发布时必须提供)" ) - + # 2. 验证所有非 start/end 节点都有配置 for node in nodes: node_type = node.get("type") - if node_type not in ["start", "end"]: + if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]: config = node.get("config") if not config or not isinstance(config, dict): errors.append( f"节点 {node.get('id')} 缺少配置(发布时必须提供)" ) - + # 3. 
验证必填变量 variables = workflow_config.get("variables", []) required_vars = [v for v in variables if v.get("required")] @@ -254,13 +322,13 @@ class WorkflowValidator: f"工作流包含 {len(required_vars)} 个必填变量: " f"{[v.get('name') for v in required_vars]}" ) - + return len(errors) == 0, errors def validate_workflow_config( - workflow_config: dict[str, Any], - for_publish: bool = False + workflow_config: dict[str, Any], + for_publish: bool = False ) -> tuple[bool, list[str]]: """验证工作流配置(便捷函数) diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 0f97c349..b7814f28 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -198,19 +198,22 @@ class VariablePool: namespace = selector[0] - if namespace != "conv": - raise ValueError("只能设置会话变量 (conv.*)") + if namespace != "conv" and namespace not in self.state["cycle_nodes"]: + raise ValueError("Only conversation or cycle variables can be assigned.") key = selector[1] # 确保 variables 结构存在 if "variables" not in self.state: self.state["variables"] = {"sys": {}, "conv": {}} - if "conv" not in self.state["variables"]: - self.state["variables"]["conv"] = {} - - # 设置值 - self.state["variables"]["conv"][key] = value + if namespace == "conv": + if "conv" not in self.state["variables"]: + self.state["variables"]["conv"] = {} + + # 设置值 + self.state["variables"]["conv"][key] = value + elif namespace in self.state["cycle_nodes"]: + self.state["runtime_vars"][namespace][key] = value logger.debug(f"设置变量: {'.'.join(selector)} = {value}") diff --git a/api/app/models/document_model.py b/api/app/models/document_model.py index db9280c6..fb43d44d 100644 --- a/api/app/models/document_model.py +++ b/api/app/models/document_model.py @@ -26,6 +26,7 @@ class Document(Base): "html4excel": False, "graphrag": { "use_graphrag": False, + "scene_name": "", "entity_types": [ "organization", "person", @@ -33,7 +34,9 @@ class Document(Base): "event", "category" ], - "method": 
"general" + "method": "general", + "resolution": True, + "community": True } }, comment="default parser config") chunk_num = Column(Integer, default=0, comment="chunk num") diff --git a/api/app/models/knowledge_model.py b/api/app/models/knowledge_model.py index 6d3465f9..8f0909d3 100644 --- a/api/app/models/knowledge_model.py +++ b/api/app/models/knowledge_model.py @@ -65,6 +65,7 @@ class Knowledge(Base): "html4excel": False, "graphrag": { "use_graphrag": False, + "scene_name": "", "entity_types": [ "organization", "person", @@ -72,7 +73,9 @@ class Knowledge(Base): "event", "category" ], - "method": "general" + "method": "general", + "resolution": True, + "community": True } }, comment="default parser config") diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index b7908cb0..681d1c10 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -58,13 +58,13 @@ def get_chunked_knowledgeids( ) -> list: """ Query the list of vectorized knowledge base IDs - Return: list[UUID] - List of knowledge base IDs + Return: list[(id,workspace_id)] - List of knowledge base id and workspace_id """ db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}") try: # Only query the id field - query = db.query(Knowledge.id) + query = db.query(Knowledge.id, Knowledge.workspace_id) # Apply filter conditions for filter_cond in filters: @@ -74,8 +74,8 @@ def get_chunked_knowledgeids( items = query.all() db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}") - # Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column - return [item[0] for item in items] + # Return the list of ID and workspace_id directly. 
Since only the ID and workspace_id field is queried + return items except Exception as e: db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}") raise diff --git a/api/app/repositories/knowledgeshare_repository.py b/api/app/repositories/knowledgeshare_repository.py index e4976b8d..28bb07e9 100644 --- a/api/app/repositories/knowledgeshare_repository.py +++ b/api/app/repositories/knowledgeshare_repository.py @@ -61,14 +61,14 @@ def get_source_kb_ids_by_target_kb_id( ) -> list: """ Query the original knowledge base ID list by sharing the knowledge base - Return: list[UUID] - List of knowledge base IDs + Return: list[(source_kb_id,source_workspace_id)] - List of knowledge base source_kb_id and source_workspace_id """ db_logger.debug( f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}") try: # Only query the id field - query = db.query(KnowledgeShare.source_kb_id) + query = db.query(KnowledgeShare.source_kb_id, KnowledgeShare.source_workspace_id) # Apply filter conditions for filter_cond in filters: @@ -78,8 +78,8 @@ def get_source_kb_ids_by_target_kb_id( items = query.all() db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}") - # Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column - return [item[0] for item in items] + # Return the list of source_kb_id and source_workspace_id directly. 
Since only the source_kb_id and source_workspace_id field is queried + return items except Exception as e: db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}") raise diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index de0a4c53..81cd704d 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -32,6 +32,7 @@ class KnowledgeRetrievalConfig(BaseModel): ) reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") + use_graph: bool = Field(default=False, description="是否使用图搜索") class ToolConfig(BaseModel): diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index c38fe765..cef9b9cb 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -10,6 +10,7 @@ class RetrieveType(StrEnum): PARTICIPLE = "participle" SEMANTIC = "semantic" HYBRID = "hybrid" + Graph = "graph" class ChunkCreate(BaseModel): diff --git a/api/app/schemas/memory_reflection_schemas.py b/api/app/schemas/memory_reflection_schemas.py index ada92cf2..860f1ef1 100644 --- a/api/app/schemas/memory_reflection_schemas.py +++ b/api/app/schemas/memory_reflection_schemas.py @@ -12,8 +12,8 @@ class Memory_Reflection(BaseModel): config_id: Optional[int] = None reflection_enabled: bool reflection_period_in_hours: str - reflexion_range: str - baseline: str + reflexion_range: Optional[str] = "partial" + baseline: Optional[str] = "TIME" reflection_model_id: str memory_verify: bool quality_assessment: bool diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index eb337298..bdef825e 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -20,6 +20,7 @@ class NodeDefinition(BaseModel): id: str = Field(..., description="节点唯一标识") type: str = Field(..., description="节点类型: start, end, llm, 
agent, tool, condition, loop, transform, human, code") name: str | None = Field(None, description="节点名称") + cycle: str | None = Field(None, description="父循环节点id") description: str | None = Field(None, description="节点描述") config: dict[str, Any] = Field(default_factory=dict, description="节点配置") position: dict[str, float] | None = Field(None, description="节点位置 {x, y}") diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index ab1af3c2..dc3a7c70 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -42,7 +42,7 @@ nodes: - 适当使用格式化(如列表、段落)提高可读性 - role: user - content: "{{ sys.message }}" + content: "{{sys.message}}" model_id: null temperature: 0.7 @@ -55,7 +55,7 @@ nodes: type: end name: 结束 config: - output: "{{ llm_qa.output }}" + output: "{{llm_qa.output}}" position: x: 900 y: 100 diff --git a/api/pyproject.toml b/api/pyproject.toml index 901858e6..2dcc706d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -135,6 +135,8 @@ dependencies = [ "graspologic==3.4.5.dev2", "markdown-to-json==2.1.1", "valkey==6.0.2", + "python-calamine>=0.4.0", + "xlrd==2.0.2" ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 5530a9e3..99252e09 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -129,3 +129,5 @@ editdistance==0.8.1 graspologic==3.4.5.dev2 markdown-to-json==2.1.1 valkey==6.0.2 +python-calamine>=0.4.0 +xlrd==2.0.2