Merge branch 'refs/heads/develop' into feature/20251219_xjn

This commit is contained in:
谢俊男
2025-12-30 21:05:31 +08:00
47 changed files with 1313 additions and 403 deletions

View File

@@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.common.settings import kg_retriever
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger # Obtain a dedicated API logger
@@ -389,36 +392,41 @@ async def retrieve_chunks(
knowledge_model.Knowledge.chunk_num > 0, knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1 knowledge_model.Knowledge.status == 1
] ]
existing_ids = knowledge_service.get_chunded_knowledgeids( private_items = knowledge_service.get_chunded_knowledgeids(
db=db, db=db,
filters=filters, filters=filters,
current_user=current_user current_user=current_user
) )
private_kb_ids = [item[0] for item in private_items]
private_workspace_ids = [item[1] for item in private_items]
filters = [ filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids), knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0, knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1 knowledge_model.Knowledge.status == 1
] ]
share_ids = knowledge_service.get_chunded_knowledgeids( items = knowledge_service.get_chunded_knowledgeids(
db=db, db=db,
filters=filters, filters=filters,
current_user=current_user current_user=current_user
) )
if share_ids: if items:
filters = [ filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids) knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
] ]
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id( share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db, db=db,
filters=filters, filters=filters,
current_user=current_user current_user=current_user
) )
existing_ids.extend(items) share_kb_ids = [item[0] for item in share_items]
if not existing_ids: share_workspace_ids = [item[1] for item in share_items]
private_kb_ids.extend(share_kb_ids)
private_workspace_ids.extend(share_workspace_ids)
if not private_kb_ids:
return success(data=[], msg="retrieval successful") return success(data=[], msg="retrieval successful")
kb_id = existing_ids[0] kb_id = private_kb_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
indices = ",".join(uuid_strs) indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge: if not db_knowledge:
@@ -448,4 +456,21 @@ async def retrieve_chunks(
seen_ids.add(doc.metadata["doc_id"]) seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc) unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
# Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
rs.insert(0, doc)
return success(data=rs, msg="retrieval successful") return success(data=rs, msg="retrieval successful")

View File

@@ -86,5 +86,5 @@
- **quality_assessment**: - **quality_assessment**:
quality_assessment=true时输出评估对象否则为null注意- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) quality_assessment=true时输出评估对象否则为null注意- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
- **memory_verify**: memory_verify=true时输出隐私检测对象否则为null - **memory_verify**: memory_verify=true时输出隐私检测对象否则为null
(注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容) (注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容)
模式参考:{{ json_schema }} 模式参考:{{ json_schema }}

View File

@@ -1,4 +1,3 @@
from typing import Any, Dict, List, Optional, Sequence, Type, Union from typing import Any, Dict, List, Optional, Sequence, Type, Union
from copy import deepcopy from copy import deepcopy
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
from app.models import ModelProvider from app.models import ModelProvider
class RedBearRerank(BaseDocumentCompressor): class RedBearRerank(BaseDocumentCompressor):
""" Rerank → 作为 Runnable 插入任意 LCEL 链""" """ Rerank → 作为 Runnable 插入任意 LCEL 链"""
def __init__(self, config: RedBearModelConfig): def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config) self._model = self._create_model(config)
self._config = config self._config = config
@@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor):
return model_class(**model_params) return model_class(**model_params)
def compress_documents( def compress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],
query: str, query: str,
callbacks: Optional[Callbacks] = None, callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]: ) -> Sequence[Document]:
""" """
Compress documents using Jina's Rerank API. Compress documents using Jina's Rerank API.
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor):
compressed.append(doc_copy) compressed.append(doc_copy)
return compressed return compressed
def rerank( def rerank(
self, self,
documents: Sequence[Union[str, Document, dict]], documents: Sequence[Union[str, Document, dict]],
query: str, query: str,
*, *,
top_n: Optional[int] = -1, top_n: Optional[int] = -1,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
provider = self._config.provider.lower() provider = self._config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import langchain_community.document_compressors.jina_rerank as jina_mod import langchain_community.document_compressors.jina_rerank as jina_mod
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank # 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]: def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
if not base_url: if not base_url:
@@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor):
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank # 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
jina_mod.JINA_API_URL = jina_base jina_mod.JINA_API_URL = jina_base
from langchain_community.document_compressors import JinaRerank from langchain_community.document_compressors import JinaRerank
model_instance : JinaRerank = self._model model_instance: JinaRerank = self._model
return model_instance.rerank(documents = documents, query = query, top_n=top_n) return model_instance.rerank(documents=documents, query=query, top_n=top_n)
else: else:
raise ValueError(f"不支持的模型提供商: {provider}") raise ValueError(f"不支持的模型提供商: {provider}")

View File

@@ -51,7 +51,7 @@ def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs):
img_binary = io.BytesIO() img_binary = io.BytesIO()
img.save(img_binary, format="JPEG") img.save(img_binary, format="JPEG")
img_binary.seek(0) img_binary.seek(0)
ans = vision_model.describe(img_binary.read()) ans, ans_num_tokens = vision_model.describe(img_binary.read())
callback(0.8, "CV LLM respond: %s ..." % ans[:32]) callback(0.8, "CV LLM respond: %s ..." % ans[:32])
txt += "\n" + ans txt += "\n" + ans
tokenize(doc, txt, eng) tokenize(doc, txt, eng)

View File

@@ -42,11 +42,14 @@ class RAGExcelParser:
file_like_object.seek(0) file_like_object.seek(0)
try: try:
dfs = pd.read_excel(file_like_object, sheet_name=None) dfs = pd.read_excel(file_like_object, sheet_name=None)
if isinstance(dfs, dict):
dfs = next(iter(dfs.values()))
return RAGExcelParser._dataframe_to_workbook(dfs) return RAGExcelParser._dataframe_to_workbook(dfs)
except Exception as ex: except Exception as ex:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead") logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0) file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine") df = pd.read_excel(file_like_object, engine="calamine")
print(df)
return RAGExcelParser._dataframe_to_workbook(df) return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas: except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}") raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
import json_repair import json_repair
import pandas as pd import pandas as pd
import time
import trio import trio
from app.core.rag.common.misc_utils import get_uuid from app.core.rag.common.misc_utils import get_uuid
@@ -262,21 +263,21 @@ class KGSearch(Dealer):
relas = "" relas = ""
return { return {
"chunk_id": get_uuid(), "page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token),
"content_ltks": "", "vector": None,
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, "metadata": {
comm_topn, max_token), "doc_id": get_uuid(),
"file_id": "",
"file_name": "Related content in Knowledge Graph",
"file_created_at": int(time.time() * 1000),
"document_id": "", "document_id": "",
"docnm_kwd": "Related content in Knowledge Graph", "knowledge_id": kb_ids,
"kb_id": kb_ids, "sort_id": 0,
"important_kwd": [], "status": 1,
"image_id": "", "score": 1
"similarity": 1., },
"vector_similarity": 1., "children": None
"term_similarity": 0, }
"vector": [],
"positions": [],
}
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token): def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval ## Community retrieval

View File

@@ -26,6 +26,8 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
from app.core.rag.common.string_utils import remove_redundant_spaces from app.core.rag.common.string_utils import remove_redundant_spaces
from app.core.rag.common.float_utils import get_float from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
def knowledge_retrieval( def knowledge_retrieval(
@@ -48,6 +50,7 @@ def knowledge_retrieval(
- merge_strategy: "weight" or other strategies - merge_strategy: "weight" or other strategies
- reranker_id: UUID of the reranker to use - reranker_id: UUID of the reranker to use
- reranker_top_k: int - reranker_top_k: int
- use_graph: bool, whether to use a graph
Returns: Returns:
Rearranged document block list (in descending order of relevance) Rearranged document block list (in descending order of relevance)
@@ -59,6 +62,7 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight") merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id") reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024) reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
file_names_filter = [] file_names_filter = []
if user_ids: if user_ids:
@@ -67,6 +71,10 @@ def knowledge_retrieval(
if not knowledge_bases: if not knowledge_bases:
return [] return []
kb_ids = []
workspace_ids = []
chat_model = None
embedding_model = None
all_results = [] all_results = []
# Search each knowledge base # Search each knowledge base
for kb_config in knowledge_bases: for kb_config in knowledge_bases:
@@ -87,6 +95,22 @@ def knowledge_retrieval(
else: else:
continue continue
if str(db_knowledge.id) not in kb_ids:
kb_ids.append(str(db_knowledge.id))
if str(db_knowledge.workspace_id) not in workspace_ids:
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
if not embedding_model:
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type # Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]: match kb_config["retrieve_type"]:
@@ -136,6 +160,12 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking # Use the specified reranker for re-ranking
if reranker_id: if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
return all_results return all_results
except Exception as e: except Exception as e:

View File

@@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection):
m.topn * 2, m.topn * 2,
query_vector=list(m.embedding_data), query_vector=list(m.embedding_data),
filter=bqry.to_dict(), filter=bqry.to_dict(),
similarity=similarity, # similarity=similarity
) )
if bqry and rank_feature: if bqry and rank_feature:

View File

@@ -4,9 +4,9 @@
基于 LangGraph 的工作流执行引擎。 基于 LangGraph 的工作流执行引擎。
""" """
import logging
# import uuid # import uuid
import datetime import datetime
import logging
from typing import Any from typing import Any
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
@@ -107,7 +107,13 @@ class WorkflowExecutor:
"user_id": self.user_id, "user_id": self.user_id,
"error": None, "error": None,
"error_node": None, "error_node": None,
"streaming_buffer": {} # 流式缓冲区 "streaming_buffer": {}, # 流式缓冲区
"cycle_nodes": [
node.get("id")
for node in self.workflow_config.get("nodes")
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
], # loop, iteration node id
"looping": False # loop runing flag, only use in loop node,not use in main loop
} }
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
@@ -199,6 +205,10 @@ class WorkflowExecutor:
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID # 记录 start 和 end 节点 ID
if node_type == NodeType.START: if node_type == NodeType.START:
@@ -271,7 +281,7 @@ class WorkflowExecutor:
workflow.add_edge(START, start_node_id) workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}") logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges: for edge in self.workflow_config.get("edges", []):
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
edge_type = edge.get("type") edge_type = edge.get("type")
@@ -284,12 +294,12 @@ class WorkflowExecutor:
logger.debug(f"添加边: {source} -> {target}") logger.debug(f"添加边: {source} -> {target}")
continue continue
# 处理到 end 节点的边 # # 处理到 end 节点的边
if target in end_node_ids: # if target in end_node_ids:
# 连接到 end 节点 # # 连接到 end 节点
workflow.add_edge(source, target) # workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}") # logger.debug(f"添加边: {source} -> {target}")
continue # continue
# 跳过错误边(在节点内部处理) # 跳过错误边(在节点内部处理)
if edge_type == "error": if edge_type == "error":
@@ -297,22 +307,30 @@ class WorkflowExecutor:
if condition: if condition:
# 条件边 # 条件边
def router(state: WorkflowState, cond=condition, tgt=target): def make_router(cond, tgt):
"""条件路由函数""" """Dynamically generate a conditional router function to ensure each branch has a unique name."""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{tgt}"
return router_fn
router_fn = make_router(condition, target)
workflow.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else: else:
# 普通边 # 普通边

View File

@@ -74,6 +74,7 @@ class ExpressionEvaluator:
# 为了向后兼容,也支持直接访问(但会在日志中警告) # 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables) context.update(variables)
context["nodes"] = node_outputs context["nodes"] = node_outputs
context.update(node_outputs)
try: try:
# simpleeval 只支持安全的操作: # simpleeval 只支持安全的操作:

View File

@@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -40,8 +40,8 @@ class AssignerNode(BaseNode):
variable_selector = expression.split('.') variable_selector = expression.split('.')
# Only conversation variables ('conv') are allowed # Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature) if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
raise ValueError("Only conversation variables can be assigned.") raise ValueError("Only conversation or cycle variables can be assigned.")
# Get the value or expression to assign # Get the value or expression to assign
value = assignment.value value = assignment.value
@@ -55,7 +55,9 @@ class AssignerNode(BaseNode):
) )
# Select the appropriate assignment operator instance based on the target variable type # Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))( operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
pool.get(variable_selector)
)(
pool, variable_selector, value pool, variable_selector, value
) )
@@ -81,3 +83,5 @@ class AssignerNode(BaseNode):
operator.remove_last() operator.remove_last()
case _: case _:
raise ValueError(f"Invalid Operator: {assignment.operation}") raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -14,9 +14,13 @@ class VariableType(StrEnum):
STRING = "string" STRING = "string"
NUMBER = "number" NUMBER = "number"
BOOLEAN = "boolean" BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object" OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class VariableDefinition(BaseModel): class VariableDefinition(BaseModel):
"""变量定义 """变量定义

View File

@@ -20,40 +20,44 @@ logger = logging.getLogger(__name__)
class WorkflowState(TypedDict): class WorkflowState(TypedDict):
"""工作流状态 """Workflow state
在节点间传递的状态对象,包含消息、变量、节点输出等信息。 The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
""" """
# 消息列表(追加模式) # List of messages (append mode)
messages: Annotated[list[AnyMessage], add] messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入) # Set of loop node IDs, used for assigning values in loop nodes
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx cycle_nodes: list
looping: bool
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: { variables: Annotated[dict[str, Any], lambda x, y: {
**x, **x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()} for k, v in y.items()}
}] }]
# 节点输出(存储每个节点的执行结果,用于变量引用) # Node outputs (stores execution results of each node for variable references)
# 使用自定义合并函数,将新的节点输出合并到现有字典中 # Uses a custom merge function to combine new node outputs into the existing dictionary
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问) # Runtime node variables (simplified version, stores business data for fast access between nodes)
# 格式:{node_id: business_result} # Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文 # Execution context
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
# 错误信息(用于错误边) # Error information (for error edges)
error: str | None error: str | None
error_node: str | None error_node: str | None
# 流式缓冲区(存储节点的实时流式输出) # Streaming buffer (stores real-time streaming output of nodes)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}} # Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}] streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
@@ -74,6 +78,7 @@ class BaseNode(ABC):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.node_id = node_config["id"] self.node_id = node_config["id"]
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
@@ -170,10 +175,10 @@ class BaseNode(ABC):
import time import time
start_time = time.time() start_time = time.time()
timeout = self.get_timeout()
try: try:
timeout = self.get_timeout()
# 调用节点的业务逻辑 # 调用节点的业务逻辑
business_result = await asyncio.wait_for( business_result = await asyncio.wait_for(
self.execute(state), self.execute(state),
@@ -200,7 +205,8 @@ class BaseNode(ABC):
**wrapped_output, **wrapped_output,
"runtime_vars": { "runtime_vars": {
self.node_id: runtime_var self.node_id: runtime_var
} },
"looping": state["looping"]
} }
except TimeoutError: except TimeoutError:
@@ -236,10 +242,10 @@ class BaseNode(ABC):
import time import time
start_time = time.time() start_time = time.time()
timeout = self.get_timeout()
try: try:
timeout = self.get_timeout()
# Get LangGraph's stream writer for sending custom data # Get LangGraph's stream writer for sending custom data
writer = get_stream_writer() writer = get_stream_writer()

View File

@@ -0,0 +1,3 @@
from app.core.workflow.nodes.breaker.node import BreakNode
__all__ = ["BreakNode"]

View File

@@ -0,0 +1,33 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class BreakNode(BaseNode):
"""
Workflow node that immediately stops loop execution.
When executed, this node sets the 'looping' flag in the workflow state
to False, signaling the outer loop runtime to terminate further iterations.
"""
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the break node.
Args:
state: Current workflow state, including loop control flags.
Effects:
- Sets 'looping' in the state to False to stop the loop.
- Logs the action for debugging purposes.
Returns:
Optional dictionary indicating the loop has been stopped.
"""
state["looping"] = False
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")

View File

@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
"BaseNodeConfig", "BaseNodeConfig",
@@ -41,5 +42,7 @@ __all__ = [
"JinjaRenderNodeConfig", "JinjaRenderNodeConfig",
"VariableAggregatorNodeConfig", "VariableAggregatorNodeConfig",
"ParameterExtractorNodeConfig", "ParameterExtractorNodeConfig",
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig" "QuestionClassifierNodeConfig"
] ]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig']

View File

@@ -0,0 +1,96 @@
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class CycleVariable(BaseNodeConfig):
name: str = Field(
...,
description="Name of the loop variable"
)
type: VariableType = Field(
...,
description="Data type of the loop variable"
)
value: str = Field(
...,
description="Initial or current value of the loop variable"
)
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
...,
description="Operator used to compare the left and right operands"
)
left: str = Field(
...,
description="Left-hand operand of the comparison expression"
)
right: str = Field(
...,
description="Right-hand operand of the comparison expression"
)
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
description="Logical operator used to combine multiple condition expressions"
)
expressions: list[ConditionDetail] = Field(
...,
description="Collection of condition expressions to be evaluated"
)
class LoopNodeConfig(BaseNodeConfig):
condition: ConditionsConfig = Field(
default_factory=list,
description="Conditional configuration that controls loop execution"
)
cycle_vars: list[CycleVariable] = Field(
default_factory=list,
description="List of variables used and updated during the loop"
)
max_loop: int = Field(
default=10,
description="Maximum number of loop iterations"
)
class IterationNodeConfig(BaseNodeConfig):
input: str = Field(
...,
description="Input of the loop iteration"
)
parallel: bool = Field(
default=False,
description="Whether to execute loop iterations in parallel"
)
parallel_count: int = Field(
default=4,
description="Number of iterations to run in parallel"
)
flatten: bool = Field(
default=False,
description="Whether to flatten the output list from iterations"
)
output: str = Field(
...,
description="Output of the loop iteration"
)

View File

@@ -0,0 +1,154 @@
import asyncio
import copy
import logging
import re
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class IterationRuntime:
"""
Runtime executor for loop/iteration nodes in a workflow.
This class handles executing iterations over a list variable, supporting
optional parallel execution, flattening of output, and loop control via
the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the iteration runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.output_value = None
self.result: list = []
def _init_iteration_state(self, item, idx):
"""
Initialize a per-iteration copy of the workflow state.
Args:
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Returns:
A deep copy of the workflow state with iteration-specific variables set.
"""
loopstate = WorkflowState(
**copy.deepcopy(self.state)
)
loopstate["runtime_vars"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["looping"] = True
return loopstate
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Args:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
def _create_iteration_tasks(self, array_obj, idx):
"""
Create async tasks for a batch of iterations based on parallel count.
Args:
array_obj: The input array to iterate over.
idx: Starting index for this batch of iterations.
Returns:
List of coroutine tasks ready to be executed in parallel.
"""
tasks = []
for i in range(self.typed_config.parallel_count):
if idx + i >= len(array_obj):
break
item = array_obj[idx + i]
tasks.append(self.run_task(item, idx + i))
return tasks
async def run(self):
"""
Execute the loop over the input array according to configuration.
Returns:
A list of outputs from all iterations, optionally flattened.
Raises:
RuntimeError: If the input variable is not a list.
"""
pattern = r"\{\{\s*(.*?)\s*\}\}"
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
array_obj = VariablePool(self.state).get(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
idx = 0
if self.typed_config.parallel:
# Execute iterations in parallel batches
while idx < len(array_obj) and self.looping:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
await asyncio.gather(*tasks)
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result

View File

@@ -0,0 +1,130 @@
import logging
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class LoopRuntime:
"""
Runtime executor for loop nodes in a workflow.
Handles iterative execution of a loop node according to defined loop variables
and conditional expressions. Supports maximum loop count and loop control
through the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the loop runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing loop node configuration.
state: Current workflow state at the point of loop execution.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = LoopNodeConfig(**config)
def _init_loop_state(self):
"""
Initialize workflow state for loop execution.
- Evaluates initial values of loop variables.
- Stores loop variables in runtime_vars and node_outputs.
- Marks the loop as active by setting 'looping' to True.
Returns:
A copy of the workflow state prepared for the loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
self.state["runtime_vars"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
**self.state
)
loopstate["looping"] = True
return loopstate
def _get_loop_expression(self):
"""
Build the Python boolean expression for evaluating the loop condition.
- Converts each condition in the loop configuration into a Python expression string.
- Combines multiple conditions with the configured logical operator (AND/OR).
Returns:
A string representing the combined loop condition expression.
"""
branch_conditions = [
ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
for condition in self.typed_config.condition.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
return combined_condition
async def run(self):
"""
Execute the loop node until the condition is no longer met, the loop is
manually stopped, or the maximum loop count is reached.
Returns:
The final runtime variables of this loop node after completion.
"""
loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop
while evaluate_condition(
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id]

View File

@@ -0,0 +1,226 @@
import logging
from typing import Any
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode):
"""
Node representing a cycle (loop) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph
for execution, handles conditional routing, and executes loop
or iteration logic based on node type.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle
self.end_node_ids = [] # IDs of end nodes within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.build_graph()
self.iteration_flag = True
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id
]
self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
def create_node(self):
"""
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
Special handling is applied for conditional nodes to generate
edge conditions based on node outputs.
"""
from app.core.workflow.nodes import NodeFactory
for node in self.cycle_nodes:
node_type = node.get("type")
node_id = node.get("id")
if node_type == NodeType.CYCLE_START:
self.start_node_id = node_id
continue
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
"""
Connect nodes within the cycle subgraph by adding edges to the internal graph.
Conditional edges are routed based on evaluated expressions.
Start and end nodes are connected to global START and END nodes.
"""
for edge in self.cycle_edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(START, target)
logger.debug(f"添加边: {source} -> {target}")
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
self.graph.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self):
"""
Build the internal subgraph for the cycle node.
Steps:
1. Extract cycle nodes and edges.
2. Create node instances and add them to the graph.
3. Connect edges and conditional routes.
4. Compile the graph for execution.
"""
self.graph = StateGraph(WorkflowState)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node()
self.create_edge()
self.graph = self.graph.compile()
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime)
or an iteration (IterationRuntime) over the internal subgraph.
Args:
state: Current workflow state.
Returns:
Runtime result of the cycle, typically the final loop/iteration variables.
Raises:
RuntimeError: If node type is unrecognized.
"""
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
raise RuntimeError("Unknown cycle node type")

View File

@@ -1,14 +1,5 @@
from enum import StrEnum from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum): class NodeType(StrEnum):
START = "start" START = "start"
@@ -27,6 +18,10 @@ class NodeType(StrEnum):
JINJARENDER = "jinja-render" JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator" VAR_AGGREGATOR = "var-aggregator"
PARAMETER_EXTRACTOR = "parameter-extractor" PARAMETER_EXTRACTOR = "parameter-extractor"
LOOP = "loop"
ITERATION = "iteration"
CYCLE_START = "cycle-start"
BREAK = "break"
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):
@@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum):
REMOVE_LAST = "remove_last" REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first" REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")
class HttpRequestMethod(StrEnum): class HttpRequestMethod(StrEnum):
GET = "GET" GET = "GET"

View File

@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
**self._build_content(state) **self._build_content(state)
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
return HttpRequestNodeOutput( return HttpRequestNodeOutput(
body=resp.text, body=resp.text,
status_code=resp.status_code, status_code=resp.status_code,
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
else: else:
match self.typed_config.error_handle.method: match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE: case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput( return HttpRequestNodeOutput(
body="", body="",
status_code=resp.status_code, status_code=resp.status_code,
headers=resp.headers, headers=resp.headers,
).model_dump() ).model_dump()
case HttpErrorHandle.DEFAULT: case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
)
return self.typed_config.error_handle.default.model_dump() return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH: case HttpErrorHandle.BRANCH:
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR" return "ERROR"

View File

@@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel):
description="Logical operator used to combine multiple condition expressions" description="Logical operator used to combine multiple condition expressions"
) )
conditions: list[ConditionDetail] = Field( expressions: list[ConditionDetail] = Field(
..., ...,
description="List of condition expressions within this branch" description="List of condition expressions within this branch"
) )
@@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / IF Branch # CASE1 / IF Branch
{ {
"logical_operator": "and", "logical_operator": "and",
"conditions": [ "expressions": [
[ [
{ {
"left": "node.userinput.message", "left": "node.userinput.message",
@@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / ELIF Branch # CASE1 / ELIF Branch
{ {
"logical_operator": "or", "logical_operator": "or",
"conditions": [ "expressions": [
[ [
{ {
"left": "node.userinput.test", "left": "node.userinput.test",

View File

@@ -2,93 +2,13 @@ import logging
from typing import Any from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail from app.core.workflow.nodes.if_else.config import ConditionDetail
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startwith(self):
return f'{self.left}.startswith({self.right})'
def _endwith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startwith()
case ComparisonOperator.END_WITH:
return self._endwith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
@@ -143,7 +63,7 @@ class IfElseNode(BaseNode):
branch_conditions = [ branch_conditions = [
self._build_condition_expression(condition) self._build_condition_expression(condition)
for condition in case_branch.conditions for condition in case_branch.expressions
] ]
if len(branch_conditions) > 1: if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
@@ -174,5 +94,6 @@ class IfElseNode(BaseNode):
for i in range(len(expressions)): for i in range(len(expressions)):
logger.info(expressions[i]) logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state): if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}' return f'CASE{i + 1}'
return f'CASE{len(expressions)}' return f'CASE{len(expressions)}'

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any from typing import Any
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
res = render.env.from_string(self.typed_config.template).render(**context) res = render.env.from_string(self.typed_config.template).render(**context)
except Exception as e: except Exception as e:
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
return res return res

View File

@@ -1,18 +1,13 @@
from uuid import UUID from uuid import UUID
from pydantic import Field from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
class KnowledgeRetrievalNodeConfig(BaseNodeConfig): class KnowledgeBaseConfig(BaseModel):
query: str = Field( kb_id: UUID = Field(
...,
description="Search query string"
)
kb_ids: list[UUID] = Field(
..., ...,
description="Knowledge base IDs" description="Knowledge base IDs"
) )
@@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
description="Retrieve type" description="Retrieve type"
) )
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
query: str = Field(
...,
description="Search query string"
)
knowledge_bases: list[KnowledgeBaseConfig] = Field(
...,
description="Knowledge base config"
)
reranker_id: UUID = Field(
...,
description="Reranker top k"
)
reranker_top_k: int = Field(
default=4,
description="Knowledge base top k"
)
class Config: class Config:
json_schema_extra = { json_schema_extra = {
"examples": [ "examples": [
{ {
"query": "{{sys.message}}", "query": "{{sys.message}}",
"kb_ids": [ "knowledge_bases": [{
"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" "kb_id": "xxxxxxxx-xxxx-xxxx-xxxxxxxxxxxxxxxxx",
], "similarity_threshold": 0.2,
"similarity_threshold": 0.2, "vector_similarity_weight": 0.3,
"vector_similarity_weight": 0.3, "top_k": 4,
"top_k": 1, "retrieve_type": "hybrid"
"retrieve_type": "hybrid" }],
"reranker_top_k": 1,
"reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
} }
] ]
} }

View File

@@ -2,14 +2,18 @@ import logging
import uuid import uuid
from typing import Any from typing import Any
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_read from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service from app.services import knowledge_service, knowledgeshare_service
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode):
existing_ids.extend(items) existing_ids.extend(items)
return existing_ids return existing_ids
def get_reranker_model(self) -> RedBearRerank:
"""
Retrieve and initialize a RedBear reranker model based on configuration.
Raises:
BusinessException: If configuration is missing or API keys are not set.
RuntimeError: If the configured model is not of type RERANK.
"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id)
if not config:
raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
if model_type != ModelType.RERANK:
raise RuntimeError("Model is not a reranker")
reranker = RedBearRerank(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
)
)
return reranker
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
""" """
Execute the knowledge retrieval workflow node. Execute the knowledge retrieval workflow node.
@@ -131,38 +173,45 @@ class KnowledgeRetrievalNode(BaseNode):
""" """
query = self._render_template(self.typed_config.query, state) query = self._render_template(self.typed_config.query, state)
with get_db_read() as db: with get_db_read() as db:
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids) knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
if not existing_ids: if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0] rs = []
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] for kb_config in knowledge_bases:
indices = ",".join(uuid_strs) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
if not db_knowledge: indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
raise RuntimeError("The knowledge base does not exist or access is denied.") match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
match self.typed_config.retrieve_type: score_threshold=kb_config.similarity_threshold))
case RetrieveType.PARTICIPLE: case RetrieveType.SEMANTIC:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.similarity_threshold) score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.SEMANTIC: case RetrieveType.HYBRID:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.vector_similarity_weight) score_threshold=kb_config.vector_similarity_weight)
case _: rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, indices=indices,
indices=indices, score_threshold=kb_config.similarity_threshold)
score_threshold=self.typed_config.vector_similarity_weight) # Deduplicate hybrid retrieval results
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, unique_rs = self._deduplicate_docs(rs1, rs2)
indices=indices, vector_service.reranker = self.get_reranker_model()
score_threshold=self.typed_config.similarity_threshold) rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
# Deduplicate hybrid retrieval results case _:
unique_rs = self._deduplicate_docs(rs1, rs2) raise RuntimeError("Unknown retrieval type")
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) vector_service.reranker = self.get_reranker_model()
return [chunk.model_dump() for chunk in rs] final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs]

View File

@@ -10,6 +10,7 @@ from typing import Any, Union
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.http_request import HttpRequestNode from app.core.workflow.nodes.http_request import HttpRequestNode
@@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,6 +41,9 @@ WorkflowNode = Union[
JinjaRenderNode, JinjaRenderNode,
VariableAggregatorNode, VariableAggregatorNode,
ParameterExtractorNode, ParameterExtractorNode,
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode QuestionClassifierNode
] ]
@@ -64,6 +69,9 @@ class NodeFactory:
NodeType.VAR_AGGREGATOR: VariableAggregatorNode, NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
} }
@classmethod @classmethod

View File

@@ -1,6 +1,7 @@
from abc import ABC from abc import ABC
from typing import Union, Type from typing import Union, Type
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
@@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase):
self.pool.set(self.left_selector, dict()) self.pool.set(self.left_selector, dict())
class AssignmentOperatorResolver:
@classmethod
def resolve_by_value(cls, value):
if isinstance(value, str):
return StringOperator
elif isinstance(value, bool):
return BooleanOperator
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[ AssignmentOperatorInstance = Union[
StringOperator, StringOperator,
NumberOperator, NumberOperator,
@@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[
ObjectOperator ObjectOperator
] ]
AssignmentOperatorType = Type[AssignmentOperatorInstance] AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self):
return f'{self.left}.startswith({self.right})'
def _endswith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startswith()
case ComparisonOperator.END_WITH:
return self._endswith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")

View File

@@ -36,6 +36,11 @@ class ParamsConfig(BaseModel):
description="Description of the parameter" description="Description of the parameter"
) )
required: bool = Field(
...,
description="Whether the parameter is required"
)
class ParameterExtractorNodeConfig(BaseNodeConfig): class ParameterExtractorNodeConfig(BaseNodeConfig):
model_id: uuid.UUID = Field( model_id: uuid.UUID = Field(
@@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
..., ...,
description="List of parameters" description="List of parameters"
) )
prompt: str = Field(
...,
description="User-provided supplemental prompt"
)

View File

@@ -1,4 +1,5 @@
import os import os
import logging
import json_repair import json_repair
from typing import Any from typing import Any
@@ -15,6 +16,8 @@ from app.db import get_db_read
from app.models import ModelType from app.models import ModelType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode):
""" """
field_type = {} field_type = {}
for param in self.typed_config.params: for param in self.typed_config.params:
field_type[param.name] = param.type field_type[param.name] = f'{param.type}, required:{str(param.required)}'
return field_type return field_type
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
@@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode):
messages = [ messages = [
("system", system_prompt), ("system", system_prompt),
("user", self._render_template(self.typed_config.prompt, state)),
("user", rendered_user_prompt), ("user", rendered_user_prompt),
] ]
model_resp = await llm.ainvoke(messages) model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content) result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")
return { return result
"output": result,
}

View File

@@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
description="输出变量是否需要分组", description="输出变量是否需要分组",
) )
group_names: list[str] = Field( group_variables: list[str] | dict[str, list[str]] = Field(
default_factory=lambda: ["output"],
description="各个分组的名称"
)
group_variables: list[str] | list[list[str]] = Field(
..., ...,
description="需要被聚合的变量" description="需要被聚合的变量"
) )
@field_validator("group_names", mode="before")
@classmethod
def group_names_validator(cls, v, info):
group_status = info.data.get("group")
if not group_status or not v:
return ["output"]
return v
@field_validator("group_variables") @field_validator("group_variables")
@classmethod @classmethod
def group_variables_validator(cls, v, info): def group_variables_validator(cls, v, info):
group_status = info.data.get("group") group_status = info.data.get("group")
group_names = info.data.get("group_names")
if not isinstance(v, list):
raise ValueError("group_variables must be a list")
if not group_status: if not group_status:
for variable in v: for variable in v:
if not isinstance(variable, str): if not isinstance(variable, str):
raise ValueError("When group=False, group_variables must be a list of strings") raise ValueError("When group=False, group_variables must be a list of strings")
else: else:
if len(group_names) != len(v): if not isinstance(v, dict):
raise ValueError("group_names and group_variables length mismatch") raise ValueError("When group=True, group_variables must be a dict")
for group in v: for group_name, group_values in v.items():
if not isinstance(group, list): if not isinstance(group_name, str):
raise ValueError("When group=True, each element of group_variables must be a list") raise ValueError("When group=True, each element of group_variables must be a list")
for variable in group: for variable in group_values:
if not isinstance(variable, str): if not isinstance(variable, str):
raise ValueError("Each element inside group_variables lists must be a string") raise ValueError("Each element inside group_variables lists must be a string")
return v return v

View File

@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
continue continue
if value is not None: if value is not None:
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
return value return value
logger.info("No variable found in non-group mode; returning empty string.") logger.info("No variable found in non-group mode; returning empty string.")
@@ -59,7 +60,7 @@ class VariableAggregatorNode(BaseNode):
# Group mode # Group mode
# -------------------------- # --------------------------
result = {} result = {}
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables): for group_name, variables in self.typed_config.group_variables.items():
for variable in variables: for variable in variables:
var_express = self._get_express(variable) var_express = self._get_express(variable)
try: try:
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
else: else:
result[group_name] = "" result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.") logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result return result

View File

@@ -7,14 +7,87 @@
import logging import logging
from typing import Any, Union from typing import Any, Union
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowValidator: class WorkflowValidator:
"""工作流配置验证器""" """工作流配置验证器"""
@staticmethod @classmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != node_id
]
workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
if not isinstance(workflow_config, dict):
workflow_config = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
}
cycle_nodes = [
node.get("id")
for node in workflow_config.get("nodes", [])
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
]
graphs = []
for cycle_node in cycle_nodes:
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
graphs.append({
"nodes": nodes,
"edges": edges,
})
graphs.append(workflow_config)
return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置 """验证工作流配置
Args: Args:
@@ -38,84 +111,79 @@ class WorkflowValidator:
True True
""" """
errors = [] errors = []
# 支持字典和 Pydantic 模型 graphs = cls.get_subgraph(workflow_config)
if isinstance(workflow_config, dict): logger.info(graphs)
nodes = workflow_config.get("nodes", []) for graph in graphs:
edges = workflow_config.get("edges", []) nodes = graph.get("nodes", [])
variables = workflow_config.get("variables", []) edges = graph.get("edges", [])
else: variables = graph.get("variables", [])
# Pydantic 模型 # 1. 验证 start 节点(有且只有一个)
nodes = getattr(workflow_config, "nodes", []) start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
edges = getattr(workflow_config, "edges", []) if len(start_nodes) == 0:
variables = getattr(workflow_config, "variables", []) errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
# 1. 验证 start 节点(有且只有一个) errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0: # 2. 验证 end 节点(至少一个)
errors.append("工作流必须有一个 start 节点") end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
elif len(start_nodes) > 1: if len(end_nodes) == 0:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}") errors.append("工作流必须至少有一个 end 节点")
# 2. 验证 end 节点(至少一个) # 3. 验证节点 ID 唯一性
end_nodes = [n for n in nodes if n.get("type") == "end"] node_ids = [n.get("id") for n in nodes]
if len(end_nodes) == 0: if len(node_ids) != len(set(node_ids)):
errors.append("工作流必须至少有一个 end 节点") duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes] # 4. 验证节点必须有 id 和 type
if len(node_ids) != len(set(node_ids)): for i, node in enumerate(nodes):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] if not node.get("id"):
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
# 4. 验证节点必须有 id 和 type errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
for i, node in enumerate(nodes):
if not node.get("id"): # 5. 验证边的有效性
errors.append(f"节点 #{i} 缺少 id 字段") node_id_set = set(node_ids)
if not node.get("type"): for i, edge in enumerate(edges):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段") source = edge.get("source")
target = edge.get("target")
# 5. 验证边的有效性
node_id_set = set(node_ids) if not source:
for i, edge in enumerate(edges): errors.append(f"边 #{i} 缺少 source 字段")
source = edge.get("source") elif source not in node_id_set:
target = edge.get("target") errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not source: if not target:
errors.append(f"边 #{i} 缺少 source 字段") errors.append(f"边 #{i} 缺少 target 字段")
elif source not in node_id_set: elif target not in node_id_set:
errors.append(f"边 #{i}source 节点不存在: {source}") errors.append(f"边 #{i}target 节点不存在: {target}")
if not target: # 6. 验证所有节点可达(从 start 节点出发)
errors.append(f"边 #{i} 缺少 target 字段") if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
elif target not in node_id_set: reachable = WorkflowValidator._get_reachable_nodes(
errors.append(f"边 #{i} 的 target 节点不存在: {target}") start_nodes[0]["id"],
edges
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
) )
unreachable = node_id_set - reachable
# 8. 验证变量名 if unreachable:
from app.core.workflow.expression_evaluator import ExpressionEvaluator errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors) # 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors return len(errors) == 0, errors
@staticmethod @staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点 """获取从 start 节点可达的所有节点
@@ -129,7 +197,7 @@ class WorkflowValidator:
""" """
reachable = {start_id} reachable = {start_id}
queue = [start_id] queue = [start_id]
while queue: while queue:
current = queue.pop(0) current = queue.pop(0)
for edge in edges: for edge in edges:
@@ -138,9 +206,9 @@ class WorkflowValidator:
if target and target not in reachable: if target and target not in reachable:
reachable.add(target) reachable.add(target)
queue.append(target) queue.append(target)
return reachable return reachable
@staticmethod @staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]: def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS """检测是否存在循环依赖DFS
@@ -154,39 +222,39 @@ class WorkflowValidator:
""" """
# 排除 loop 类型的节点 # 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边) # 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {} graph: dict[str, list[str]] = {}
for edge in edges: for edge in edges:
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
edge_type = edge.get("type") edge_type = edge.get("type")
# 跳过错误边 # 跳过错误边
if edge_type == "error": if edge_type == "error":
continue continue
# 如果涉及 loop 节点,跳过 # 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes: if source in loop_nodes or target in loop_nodes:
continue continue
if source and target: if source and target:
if source not in graph: if source not in graph:
graph[source] = [] graph[source] = []
graph[source].append(target) graph[source].append(target)
# DFS 检测环 # DFS 检测环
visited = set() visited = set()
rec_stack = set() rec_stack = set()
path = [] path = []
cycle_path = [] cycle_path = []
def dfs(node: str) -> bool: def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环""" """DFS 检测环,返回是否找到环"""
visited.add(node) visited.add(node)
rec_stack.add(node) rec_stack.add(node)
path.append(node) path.append(node)
for neighbor in graph.get(node, []): for neighbor in graph.get(node, []):
if neighbor not in visited: if neighbor not in visited:
if dfs(neighbor): if dfs(neighbor):
@@ -196,19 +264,19 @@ class WorkflowValidator:
cycle_start = path.index(neighbor) cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor]) cycle_path.extend([*path[cycle_start:], neighbor])
return True return True
rec_stack.remove(node) rec_stack.remove(node)
path.pop() path.pop()
return False return False
# 检查所有节点 # 检查所有节点
for node_id in graph: for node_id in graph:
if node_id not in visited: if node_id not in visited:
if dfs(node_id): if dfs(node_id):
return True, cycle_path return True, cycle_path
return False, [] return False, []
@staticmethod @staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]: def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证) """验证工作流配置是否可以发布(更严格的验证)
@@ -221,30 +289,30 @@ class WorkflowValidator:
""" """
# 先执行基础验证 # 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config) is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid: if not is_valid:
return False, errors return False, errors
# 额外的发布验证 # 额外的发布验证
nodes = workflow_config.get("nodes", []) nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称 # 1. 验证所有节点都有名称
for node in nodes: for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"): if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append( errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)" f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
) )
# 2. 验证所有非 start/end 节点都有配置 # 2. 验证所有非 start/end 节点都有配置
for node in nodes: for node in nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type not in ["start", "end"]: if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
config = node.get("config") config = node.get("config")
if not config or not isinstance(config, dict): if not config or not isinstance(config, dict):
errors.append( errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)" f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
) )
# 3. 验证必填变量 # 3. 验证必填变量
variables = workflow_config.get("variables", []) variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")] required_vars = [v for v in variables if v.get("required")]
@@ -254,13 +322,13 @@ class WorkflowValidator:
f"工作流包含 {len(required_vars)} 个必填变量: " f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}" f"{[v.get('name') for v in required_vars]}"
) )
return len(errors) == 0, errors return len(errors) == 0, errors
def validate_workflow_config( def validate_workflow_config(
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
for_publish: bool = False for_publish: bool = False
) -> tuple[bool, list[str]]: ) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数) """验证工作流配置(便捷函数)

View File

@@ -198,19 +198,22 @@ class VariablePool:
namespace = selector[0] namespace = selector[0]
if namespace != "conv": if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("只能设置会话变量 (conv.*)") raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1] key = selector[1]
# 确保 variables 结构存在 # 确保 variables 结构存在
if "variables" not in self.state: if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}} self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]: if namespace == "conv":
self.state["variables"]["conv"] = {} if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value # 设置值
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}") logger.debug(f"设置变量: {'.'.join(selector)} = {value}")

View File

@@ -26,6 +26,7 @@ class Document(Base):
"html4excel": False, "html4excel": False,
"graphrag": { "graphrag": {
"use_graphrag": False, "use_graphrag": False,
"scene_name": "",
"entity_types": [ "entity_types": [
"organization", "organization",
"person", "person",
@@ -33,7 +34,9 @@ class Document(Base):
"event", "event",
"category" "category"
], ],
"method": "general" "method": "general",
"resolution": True,
"community": True
} }
}, comment="default parser config") }, comment="default parser config")
chunk_num = Column(Integer, default=0, comment="chunk num") chunk_num = Column(Integer, default=0, comment="chunk num")

View File

@@ -65,6 +65,7 @@ class Knowledge(Base):
"html4excel": False, "html4excel": False,
"graphrag": { "graphrag": {
"use_graphrag": False, "use_graphrag": False,
"scene_name": "",
"entity_types": [ "entity_types": [
"organization", "organization",
"person", "person",
@@ -72,7 +73,9 @@ class Knowledge(Base):
"event", "event",
"category" "category"
], ],
"method": "general" "method": "general",
"resolution": True,
"community": True
} }
}, },
comment="default parser config") comment="default parser config")

View File

@@ -58,13 +58,13 @@ def get_chunked_knowledgeids(
) -> list: ) -> list:
""" """
Query the list of vectorized knowledge base IDs Query the list of vectorized knowledge base IDs
Return: list[UUID] - List of knowledge base IDs Return: list[(id,workspace_id)] - List of knowledge base id and workspace_id
""" """
db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}") db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}")
try: try:
# Only query the id field # Only query the id field
query = db.query(Knowledge.id) query = db.query(Knowledge.id, Knowledge.workspace_id)
# Apply filter conditions # Apply filter conditions
for filter_cond in filters: for filter_cond in filters:
@@ -74,8 +74,8 @@ def get_chunked_knowledgeids(
items = query.all() items = query.all()
db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}") db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column # Return the list of ID and workspace_id directly. Since only the ID and workspace_id field is queried
return [item[0] for item in items] return items
except Exception as e: except Exception as e:
db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}") db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}")
raise raise

View File

@@ -61,14 +61,14 @@ def get_source_kb_ids_by_target_kb_id(
) -> list: ) -> list:
""" """
Query the original knowledge base ID list by sharing the knowledge base Query the original knowledge base ID list by sharing the knowledge base
Return: list[UUID] - List of knowledge base IDs Return: list[(source_kb_id,source_workspace_id)] - List of knowledge base source_kb_id and source_workspace_id
""" """
db_logger.debug( db_logger.debug(
f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}") f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}")
try: try:
# Only query the id field # Only query the id field
query = db.query(KnowledgeShare.source_kb_id) query = db.query(KnowledgeShare.source_kb_id, KnowledgeShare.source_workspace_id)
# Apply filter conditions # Apply filter conditions
for filter_cond in filters: for filter_cond in filters:
@@ -78,8 +78,8 @@ def get_source_kb_ids_by_target_kb_id(
items = query.all() items = query.all()
db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}") db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column # Return the list of source_kb_id and source_workspace_id directly. Since only the source_kb_id and source_workspace_id field is queried
return [item[0] for item in items] return items
except Exception as e: except Exception as e:
db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}") db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}")
raise raise

View File

@@ -32,6 +32,7 @@ class KnowledgeRetrievalConfig(BaseModel):
) )
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID") reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数") reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
use_graph: bool = Field(default=False, description="是否使用图搜索")
class ToolConfig(BaseModel): class ToolConfig(BaseModel):

View File

@@ -10,6 +10,7 @@ class RetrieveType(StrEnum):
PARTICIPLE = "participle" PARTICIPLE = "participle"
SEMANTIC = "semantic" SEMANTIC = "semantic"
HYBRID = "hybrid" HYBRID = "hybrid"
Graph = "graph"
class ChunkCreate(BaseModel): class ChunkCreate(BaseModel):

View File

@@ -12,8 +12,8 @@ class Memory_Reflection(BaseModel):
config_id: Optional[int] = None config_id: Optional[int] = None
reflection_enabled: bool reflection_enabled: bool
reflection_period_in_hours: str reflection_period_in_hours: str
reflexion_range: str reflexion_range: Optional[str] = "partial"
baseline: str baseline: Optional[str] = "TIME"
reflection_model_id: str reflection_model_id: str
memory_verify: bool memory_verify: bool
quality_assessment: bool quality_assessment: bool

View File

@@ -20,6 +20,7 @@ class NodeDefinition(BaseModel):
id: str = Field(..., description="节点唯一标识") id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
name: str | None = Field(None, description="节点名称") name: str | None = Field(None, description="节点名称")
cycle: str | None = Field(None, description="父循环节点id")
description: str | None = Field(None, description="节点描述") description: str | None = Field(None, description="节点描述")
config: dict[str, Any] = Field(default_factory=dict, description="节点配置") config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}") position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")

View File

@@ -42,7 +42,7 @@ nodes:
- 适当使用格式化(如列表、段落)提高可读性 - 适当使用格式化(如列表、段落)提高可读性
- role: user - role: user
content: "{{ sys.message }}" content: "{{sys.message}}"
model_id: null model_id: null
temperature: 0.7 temperature: 0.7
@@ -55,7 +55,7 @@ nodes:
type: end type: end
name: 结束 name: 结束
config: config:
output: "{{ llm_qa.output }}" output: "{{llm_qa.output}}"
position: position:
x: 900 x: 900
y: 100 y: 100

View File

@@ -135,6 +135,8 @@ dependencies = [
"graspologic==3.4.5.dev2", "graspologic==3.4.5.dev2",
"markdown-to-json==2.1.1", "markdown-to-json==2.1.1",
"valkey==6.0.2", "valkey==6.0.2",
"python-calamine>=0.4.0",
"xlrd==2.0.2"
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -129,3 +129,5 @@ editdistance==0.8.1
graspologic==3.4.5.dev2 graspologic==3.4.5.dev2
markdown-to-json==2.1.1 markdown-to-json==2.1.1
valkey==6.0.2 valkey==6.0.2
python-calamine>=0.4.0
xlrd==2.0.2