Merge branch 'refs/heads/develop' into feature/20251219_xjn

This commit is contained in:
谢俊男
2025-12-30 21:05:31 +08:00
47 changed files with 1313 additions and 403 deletions

View File

@@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse
from app.core.response_utils import success
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.common.settings import kg_retriever
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
from app.core.logging_config import get_api_logger
# Obtain a dedicated API logger
@@ -389,36 +392,41 @@ async def retrieve_chunks(
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_service.get_chunded_knowledgeids(
private_items = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
private_kb_ids = [item[0] for item in private_items]
private_workspace_ids = [item[1] for item in private_items]
filters = [
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.get_chunded_knowledgeids(
items = knowledge_service.get_chunded_knowledgeids(
db=db,
filters=filters,
current_user=current_user
)
if share_ids:
if items:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
]
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters,
current_user=current_user
)
existing_ids.extend(items)
if not existing_ids:
share_kb_ids = [item[0] for item in share_items]
share_workspace_ids = [item[1] for item in share_items]
private_kb_ids.extend(share_kb_ids)
private_workspace_ids.extend(share_workspace_ids)
if not private_kb_ids:
return success(data=[], msg="retrieval successful")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
kb_id = private_kb_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
@@ -448,4 +456,21 @@ async def retrieve_chunks(
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
# Prepare to configure chat_mdl、embedding_model、vision_model information
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
rs.insert(0, doc)
return success(data=rs, msg="retrieval successful")

View File

@@ -86,5 +86,5 @@
- **quality_assessment**:
quality_assessment=true时输出评估对象否则为null注意- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
- **memory_verify**: memory_verify=true时输出隐私检测对象否则为null
(注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
(注意:- summary输出的结果不允许含有expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容)
模式参考:{{ json_schema }}

View File

@@ -1,4 +1,3 @@
from typing import Any, Dict, List, Optional, Sequence, Type, Union
from copy import deepcopy
from urllib.parse import urlparse
@@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
from app.models import ModelProvider
class RedBearRerank(BaseDocumentCompressor):
""" Rerank → 作为 Runnable 插入任意 LCEL 链"""
def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config
@@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor):
return model_class(**model_params)
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compress documents using Jina's Rerank API.
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor):
compressed.append(doc_copy)
return compressed
def rerank(
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
*,
top_n: Optional[int] = -1,
) -> List[Dict[str, Any]]:
provider = self._config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
self,
documents: Sequence[Union[str, Document, dict]],
query: str,
*,
top_n: Optional[int] = -1,
) -> List[Dict[str, Any]]:
provider = self._config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
import langchain_community.document_compressors.jina_rerank as jina_mod
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
if not base_url:
@@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor):
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
jina_mod.JINA_API_URL = jina_base
from langchain_community.document_compressors import JinaRerank
model_instance : JinaRerank = self._model
return model_instance.rerank(documents = documents, query = query, top_n=top_n)
model_instance: JinaRerank = self._model
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
else:
raise ValueError(f"不支持的模型提供商: {provider}")

View File

@@ -51,7 +51,7 @@ def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs):
img_binary = io.BytesIO()
img.save(img_binary, format="JPEG")
img_binary.seek(0)
ans = vision_model.describe(img_binary.read())
ans, ans_num_tokens = vision_model.describe(img_binary.read())
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
txt += "\n" + ans
tokenize(doc, txt, eng)

View File

@@ -42,11 +42,14 @@ class RAGExcelParser:
file_like_object.seek(0)
try:
dfs = pd.read_excel(file_like_object, sheet_name=None)
if isinstance(dfs, dict):
dfs = next(iter(dfs.values()))
return RAGExcelParser._dataframe_to_workbook(dfs)
except Exception as ex:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine")
print(df)
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
from copy import deepcopy
import json_repair
import pandas as pd
import time
import trio
from app.core.rag.common.misc_utils import get_uuid
@@ -262,21 +263,21 @@ class KGSearch(Dealer):
relas = ""
return {
"chunk_id": get_uuid(),
"content_ltks": "",
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
comm_topn, max_token),
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token),
"vector": None,
"metadata": {
"doc_id": get_uuid(),
"file_id": "",
"file_name": "Related content in Knowledge Graph",
"file_created_at": int(time.time() * 1000),
"document_id": "",
"docnm_kwd": "Related content in Knowledge Graph",
"kb_id": kb_ids,
"important_kwd": [],
"image_id": "",
"similarity": 1.,
"vector_similarity": 1.,
"term_similarity": 0,
"vector": [],
"positions": [],
}
"knowledge_id": kb_ids,
"sort_id": 0,
"status": 1,
"score": 1
},
"children": None
}
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
## Community retrieval

View File

@@ -26,6 +26,8 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
from app.core.rag.common.string_utils import remove_redundant_spaces
from app.core.rag.common.float_utils import get_float
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
from app.core.rag.llm.chat_model import Base
from app.core.rag.llm.embedding_model import OpenAIEmbed
def knowledge_retrieval(
@@ -48,6 +50,7 @@ def knowledge_retrieval(
- merge_strategy: "weight" or other strategies
- reranker_id: UUID of the reranker to use
- reranker_top_k: int
- use_graph: bool, whether to use a graph
Returns:
Rearranged document block list (in descending order of relevance)
@@ -59,6 +62,7 @@ def knowledge_retrieval(
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
use_graph = config.get("use_graph", "false").lower() == "true"
file_names_filter = []
if user_ids:
@@ -67,6 +71,10 @@ def knowledge_retrieval(
if not knowledge_bases:
return []
kb_ids = []
workspace_ids = []
chat_model = None
embedding_model = None
all_results = []
# Search each knowledge base
for kb_config in knowledge_bases:
@@ -87,6 +95,22 @@ def knowledge_retrieval(
else:
continue
if str(db_knowledge.id) not in kb_ids:
kb_ids.append(str(db_knowledge.id))
if str(db_knowledge.workspace_id) not in workspace_ids:
workspace_ids.append(str(db_knowledge.workspace_id))
if not chat_model:
chat_model = Base(
key=db_knowledge.llm.api_keys[0].api_key,
model_name=db_knowledge.llm.api_keys[0].model_name,
base_url=db_knowledge.llm.api_keys[0].api_base
)
if not embedding_model:
embedding_model = OpenAIEmbed(
key=db_knowledge.embedding.api_keys[0].api_key,
model_name=db_knowledge.embedding.api_keys[0].model_name,
base_url=db_knowledge.embedding.api_keys[0].api_base
)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]:
@@ -136,6 +160,12 @@ def knowledge_retrieval(
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
# use graph
if use_graph:
from app.core.rag.common.settings import kg_retriever
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
if doc:
all_results.insert(0, doc)
return all_results
except Exception as e:

View File

@@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection):
m.topn * 2,
query_vector=list(m.embedding_data),
filter=bqry.to_dict(),
similarity=similarity,
# similarity=similarity
)
if bqry and rank_feature:

View File

@@ -4,9 +4,9 @@
基于 LangGraph 的工作流执行引擎。
"""
import logging
# import uuid
import datetime
import logging
from typing import Any
from langchain_core.messages import HumanMessage
@@ -107,7 +107,13 @@ class WorkflowExecutor:
"user_id": self.user_id,
"error": None,
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
"streaming_buffer": {}, # 流式缓冲区
"cycle_nodes": [
node.get("id")
for node in self.workflow_config.get("nodes")
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
], # loop, iteration node id
"looping": False # loop runing flag, only use in loop node,not use in main loop
}
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
@@ -199,6 +205,10 @@ class WorkflowExecutor:
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID
if node_type == NodeType.START:
@@ -271,7 +281,7 @@ class WorkflowExecutor:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
for edge in self.workflow_config.get("edges", []):
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
@@ -284,12 +294,12 @@ class WorkflowExecutor:
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
@@ -297,22 +307,30 @@ class WorkflowExecutor:
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
workflow.add_conditional_edges(source, router)
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{tgt}"
return router_fn
router_fn = make_router(condition, target)
workflow.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边

View File

@@ -74,6 +74,7 @@ class ExpressionEvaluator:
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval 只支持安全的操作:

View File

@@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -40,8 +40,8 @@ class AssignerNode(BaseNode):
variable_selector = expression.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
# Get the value or expression to assign
value = assignment.value
@@ -55,7 +55,9 @@ class AssignerNode(BaseNode):
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
pool.get(variable_selector)
)(
pool, variable_selector, value
)
@@ -81,3 +83,5 @@ class AssignerNode(BaseNode):
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -14,9 +14,13 @@ class VariableType(StrEnum):
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class VariableDefinition(BaseModel):
"""变量定义

View File

@@ -20,40 +20,44 @@ logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""Workflow state
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
"""
# 消息列表(追加模式)
# List of messages (append mode)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
looping: bool
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: {
**x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()}
}]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
# Node outputs (stores execution results of each node for variable references)
# Uses a custom merge function to combine new node outputs into the existing dictionary
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
# Runtime node variables (simplified version, stores business data for fast access between nodes)
# Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
# Execution context
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
# Error information (for error edges)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
# Streaming buffer (stores real-time streaming output of nodes)
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
@@ -74,6 +78,7 @@ class BaseNode(ABC):
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
@@ -170,10 +175,10 @@ class BaseNode(ABC):
import time
start_time = time.time()
timeout = self.get_timeout()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
@@ -200,7 +205,8 @@ class BaseNode(ABC):
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
},
"looping": state["looping"]
}
except TimeoutError:
@@ -236,10 +242,10 @@ class BaseNode(ABC):
import time
start_time = time.time()
timeout = self.get_timeout()
try:
timeout = self.get_timeout()
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()

View File

@@ -0,0 +1,3 @@
from app.core.workflow.nodes.breaker.node import BreakNode
__all__ = ["BreakNode"]

View File

@@ -0,0 +1,33 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class BreakNode(BaseNode):
"""
Workflow node that immediately stops loop execution.
When executed, this node sets the 'looping' flag in the workflow state
to False, signaling the outer loop runtime to terminate further iterations.
"""
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the break node.
Args:
state: Current workflow state, including loop control flags.
Effects:
- Sets 'looping' in the state to False to stop the loop.
- Logs the action for debugging purposes.
Returns:
Optional dictionary indicating the loop has been stopped.
"""
state["looping"] = False
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")

View File

@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
@@ -41,5 +42,7 @@ __all__ = [
"JinjaRenderNodeConfig",
"VariableAggregatorNodeConfig",
"ParameterExtractorNodeConfig",
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig']

View File

@@ -0,0 +1,96 @@
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class CycleVariable(BaseNodeConfig):
name: str = Field(
...,
description="Name of the loop variable"
)
type: VariableType = Field(
...,
description="Data type of the loop variable"
)
value: str = Field(
...,
description="Initial or current value of the loop variable"
)
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
...,
description="Operator used to compare the left and right operands"
)
left: str = Field(
...,
description="Left-hand operand of the comparison expression"
)
right: str = Field(
...,
description="Right-hand operand of the comparison expression"
)
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
description="Logical operator used to combine multiple condition expressions"
)
expressions: list[ConditionDetail] = Field(
...,
description="Collection of condition expressions to be evaluated"
)
class LoopNodeConfig(BaseNodeConfig):
condition: ConditionsConfig = Field(
default_factory=list,
description="Conditional configuration that controls loop execution"
)
cycle_vars: list[CycleVariable] = Field(
default_factory=list,
description="List of variables used and updated during the loop"
)
max_loop: int = Field(
default=10,
description="Maximum number of loop iterations"
)
class IterationNodeConfig(BaseNodeConfig):
input: str = Field(
...,
description="Input of the loop iteration"
)
parallel: bool = Field(
default=False,
description="Whether to execute loop iterations in parallel"
)
parallel_count: int = Field(
default=4,
description="Number of iterations to run in parallel"
)
flatten: bool = Field(
default=False,
description="Whether to flatten the output list from iterations"
)
output: str = Field(
...,
description="Output of the loop iteration"
)

View File

@@ -0,0 +1,154 @@
import asyncio
import copy
import logging
import re
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class IterationRuntime:
"""
Runtime executor for loop/iteration nodes in a workflow.
This class handles executing iterations over a list variable, supporting
optional parallel execution, flattening of output, and loop control via
the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the iteration runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing iteration node configuration.
state: Current workflow state at the point of iteration.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.output_value = None
self.result: list = []
def _init_iteration_state(self, item, idx):
"""
Initialize a per-iteration copy of the workflow state.
Args:
item: Current element from the input array for this iteration.
idx: Index of the element in the input array.
Returns:
A deep copy of the workflow state with iteration-specific variables set.
"""
loopstate = WorkflowState(
**copy.deepcopy(self.state)
)
loopstate["runtime_vars"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
}
loopstate["looping"] = True
return loopstate
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
Args:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
def _create_iteration_tasks(self, array_obj, idx):
"""
Create async tasks for a batch of iterations based on parallel count.
Args:
array_obj: The input array to iterate over.
idx: Starting index for this batch of iterations.
Returns:
List of coroutine tasks ready to be executed in parallel.
"""
tasks = []
for i in range(self.typed_config.parallel_count):
if idx + i >= len(array_obj):
break
item = array_obj[idx + i]
tasks.append(self.run_task(item, idx + i))
return tasks
async def run(self):
"""
Execute the loop over the input array according to configuration.
Returns:
A list of outputs from all iterations, optionally flattened.
Raises:
RuntimeError: If the input variable is not a list.
"""
pattern = r"\{\{\s*(.*?)\s*\}\}"
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
array_obj = VariablePool(self.state).get(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
idx = 0
if self.typed_config.parallel:
# Execute iterations in parallel batches
while idx < len(array_obj) and self.looping:
tasks = self._create_iteration_tasks(array_obj, idx)
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
await asyncio.gather(*tasks)
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if not result["looping"]:
self.looping = False
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return self.result

View File

@@ -0,0 +1,130 @@
import logging
from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class LoopRuntime:
"""
Runtime executor for loop nodes in a workflow.
Handles iterative execution of a loop node according to defined loop variables
and conditional expressions. Supports maximum loop count and loop control
through the workflow state.
"""
def __init__(
self,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
state: WorkflowState,
):
"""
Initialize the loop runtime.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing loop node configuration.
state: Current workflow state at the point of loop execution.
"""
self.graph = graph
self.state = state
self.node_id = node_id
self.typed_config = LoopNodeConfig(**config)
def _init_loop_state(self):
"""
Initialize workflow state for loop execution.
- Evaluates initial values of loop variables.
- Stores loop variables in runtime_vars and node_outputs.
- Marks the loop as active by setting 'looping' to True.
Returns:
A copy of the workflow state prepared for the loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
self.state["runtime_vars"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
**self.state
)
loopstate["looping"] = True
return loopstate
def _get_loop_expression(self):
"""
Build the Python boolean expression for evaluating the loop condition.
- Converts each condition in the loop configuration into a Python expression string.
- Combines multiple conditions with the configured logical operator (AND/OR).
Returns:
A string representing the combined loop condition expression.
"""
branch_conditions = [
ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
for condition in self.typed_config.condition.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
return combined_condition
async def run(self):
"""
Execute the loop node until the condition is no longer met, the loop is
manually stopped, or the maximum loop count is reached.
Returns:
The final runtime variables of this loop node after completion.
"""
loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop
while evaluate_condition(
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id]

View File

@@ -0,0 +1,226 @@
import logging
from typing import Any
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode):
"""
Node representing a cycle (loop) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph
for execution, handles conditional routing, and executes loop
or iteration logic based on node type.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle
self.end_node_ids = [] # IDs of end nodes within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.build_graph()
self.iteration_flag = True
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id
]
self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
def create_node(self):
"""
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
Special handling is applied for conditional nodes to generate
edge conditions based on node outputs.
"""
from app.core.workflow.nodes import NodeFactory
for node in self.cycle_nodes:
node_type = node.get("type")
node_id = node.get("id")
if node_type == NodeType.CYCLE_START:
self.start_node_id = node_id
continue
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
"""
Connect nodes within the cycle subgraph by adding edges to the internal graph.
Conditional edges are routed based on evaluated expressions.
Start and end nodes are connected to global START and END nodes.
"""
for edge in self.cycle_edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == self.start_node_id:
# 但要连接 start 到下一个节点
self.graph.add_edge(START, target)
logger.debug(f"添加边: {source} -> {target}")
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
self.graph.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
self.graph.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self):
"""
Build the internal subgraph for the cycle node.
Steps:
1. Extract cycle nodes and edges.
2. Create node instances and add them to the graph.
3. Connect edges and conditional routes.
4. Compile the graph for execution.
"""
self.graph = StateGraph(WorkflowState)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node()
self.create_edge()
self.graph = self.graph.compile()
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime)
or an iteration (IterationRuntime) over the internal subgraph.
Args:
state: Current workflow state.
Returns:
Runtime result of the cycle, typically the final loop/iteration variables.
Raises:
RuntimeError: If node type is unrecognized.
"""
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
).run()
raise RuntimeError("Unknown cycle node type")

View File

@@ -1,14 +1,5 @@
from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum):
START = "start"
@@ -27,6 +18,10 @@ class NodeType(StrEnum):
JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator"
PARAMETER_EXTRACTOR = "parameter-extractor"
LOOP = "loop"
ITERATION = "iteration"
CYCLE_START = "cycle-start"
BREAK = "break"
class ComparisonOperator(StrEnum):
@@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum):
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")
class HttpRequestMethod(StrEnum):
GET = "GET"

View File

@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
**self._build_content(state)
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")
return HttpRequestNodeOutput(
body=resp.text,
status_code=resp.status_code,
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
)
return self.typed_config.error_handle.default.model_dump()
case HttpErrorHandle.BRANCH:
logger.warning(
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"

View File

@@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel):
description="Logical operator used to combine multiple condition expressions"
)
conditions: list[ConditionDetail] = Field(
expressions: list[ConditionDetail] = Field(
...,
description="List of condition expressions within this branch"
)
@@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / IF Branch
{
"logical_operator": "and",
"conditions": [
"expressions": [
[
{
"left": "node.userinput.message",
@@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig):
# CASE1 / ELIF Branch
{
"logical_operator": "or",
"conditions": [
"expressions": [
[
{
"left": "node.userinput.test",

View File

@@ -2,93 +2,13 @@ import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startwith(self):
return f'{self.left}.startswith({self.right})'
def _endwith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startwith()
case ComparisonOperator.END_WITH:
return self._endwith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
@@ -143,7 +63,7 @@ class IfElseNode(BaseNode):
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.conditions
for condition in case_branch.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
@@ -174,5 +94,6 @@ class IfElseNode(BaseNode):
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}'
return f'CASE{len(expressions)}'

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from app.core.workflow.nodes import WorkflowState
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
res = render.env.from_string(self.typed_config.template).render(**context)
except Exception as e:
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
return res

View File

@@ -1,18 +1,13 @@
from uuid import UUID
from pydantic import Field
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
query: str = Field(
...,
description="Search query string"
)
kb_ids: list[UUID] = Field(
class KnowledgeBaseConfig(BaseModel):
kb_id: UUID = Field(
...,
description="Knowledge base IDs"
)
@@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
description="Retrieve type"
)
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
query: str = Field(
...,
description="Search query string"
)
knowledge_bases: list[KnowledgeBaseConfig] = Field(
...,
description="Knowledge base config"
)
reranker_id: UUID = Field(
...,
description="Reranker top k"
)
reranker_top_k: int = Field(
default=4,
description="Knowledge base top k"
)
class Config:
json_schema_extra = {
"examples": [
{
"query": "{{sys.message}}",
"kb_ids": [
"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
],
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.3,
"top_k": 1,
"retrieve_type": "hybrid"
"knowledge_bases": [{
"kb_id": "xxxxxxxx-xxxx-xxxx-xxxxxxxxxxxxxxxxx",
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.3,
"top_k": 4,
"retrieve_type": "hybrid"
}],
"reranker_top_k": 1,
"reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
}
]
}

View File

@@ -2,14 +2,18 @@ import logging
import uuid
from typing import Any
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
@@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode):
existing_ids.extend(items)
return existing_ids
def get_reranker_model(self) -> RedBearRerank:
"""
Retrieve and initialize a RedBear reranker model based on configuration.
Raises:
BusinessException: If configuration is missing or API keys are not set.
RuntimeError: If the configured model is not of type RERANK.
"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id)
if not config:
raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
if model_type != ModelType.RERANK:
raise RuntimeError("Model is not a reranker")
reranker = RedBearRerank(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
)
)
return reranker
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the knowledge retrieval workflow node.
@@ -131,38 +173,45 @@ class KnowledgeRetrievalNode(BaseNode):
"""
query = self._render_template(self.typed_config.query, state)
with get_db_read() as db:
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
rs = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
match self.typed_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _:
raise RuntimeError("Unknown retrieval type")
vector_service.reranker = self.get_reranker_model()
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs]

View File

@@ -10,6 +10,7 @@ from typing import Any, Union
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.http_request import HttpRequestNode
@@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
logger = logging.getLogger(__name__)
@@ -39,6 +41,9 @@ WorkflowNode = Union[
JinjaRenderNode,
VariableAggregatorNode,
ParameterExtractorNode,
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode
]
@@ -64,6 +69,9 @@ class NodeFactory:
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
}
@classmethod

View File

@@ -1,6 +1,7 @@
from abc import ABC
from typing import Union, Type
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.variable_pool import VariablePool
@@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase):
self.pool.set(self.left_selector, dict())
class AssignmentOperatorResolver:
@classmethod
def resolve_by_value(cls, value):
if isinstance(value, str):
return StringOperator
elif isinstance(value, bool):
return BooleanOperator
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[
StringOperator,
NumberOperator,
@@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[
ObjectOperator
]
AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self):
return f'{self.left}.startswith({self.right})'
def _endswith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startswith()
case ComparisonOperator.END_WITH:
return self._endswith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")

View File

@@ -36,6 +36,11 @@ class ParamsConfig(BaseModel):
description="Description of the parameter"
)
required: bool = Field(
...,
description="Whether the parameter is required"
)
class ParameterExtractorNodeConfig(BaseNodeConfig):
model_id: uuid.UUID = Field(
@@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
...,
description="List of parameters"
)
prompt: str = Field(
...,
description="User-provided supplemental prompt"
)

View File

@@ -1,4 +1,5 @@
import os
import logging
import json_repair
from typing import Any
@@ -15,6 +16,8 @@ from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode):
"""
field_type = {}
for param in self.typed_config.params:
field_type[param.name] = param.type
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
return field_type
async def execute(self, state: WorkflowState) -> Any:
@@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode):
messages = [
("system", system_prompt),
("user", self._render_template(self.typed_config.prompt, state)),
("user", rendered_user_prompt),
]
model_resp = await llm.ainvoke(messages)
result = json_repair.repair_json(model_resp.content)
result = json_repair.repair_json(model_resp.content, return_objects=True)
logger.info(f"node: {self.node_id} get params:{result}")
return {
"output": result,
}
return result

View File

@@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
description="输出变量是否需要分组",
)
group_names: list[str] = Field(
default_factory=lambda: ["output"],
description="各个分组的名称"
)
group_variables: list[str] | list[list[str]] = Field(
group_variables: list[str] | dict[str, list[str]] = Field(
...,
description="需要被聚合的变量"
)
@field_validator("group_names", mode="before")
@classmethod
def group_names_validator(cls, v, info):
group_status = info.data.get("group")
if not group_status or not v:
return ["output"]
return v
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):
group_status = info.data.get("group")
group_names = info.data.get("group_names")
if not isinstance(v, list):
raise ValueError("group_variables must be a list")
if not group_status:
for variable in v:
if not isinstance(variable, str):
raise ValueError("When group=False, group_variables must be a list of strings")
else:
if len(group_names) != len(v):
raise ValueError("group_names and group_variables length mismatch")
for group in v:
if not isinstance(group, list):
if not isinstance(v, dict):
raise ValueError("When group=True, group_variables must be a dict")
for group_name, group_values in v.items():
if not isinstance(group_name, str):
raise ValueError("When group=True, each element of group_variables must be a list")
for variable in group:
for variable in group_values:
if not isinstance(variable, str):
raise ValueError("Each element inside group_variables lists must be a string")
return v

View File

@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
continue
if value is not None:
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
return value
logger.info("No variable found in non-group mode; returning empty string.")
@@ -59,7 +60,7 @@ class VariableAggregatorNode(BaseNode):
# Group mode
# --------------------------
result = {}
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
for group_name, variables in self.typed_config.group_variables.items():
for variable in variables:
var_express = self._get_express(variable)
try:
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
else:
result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result

View File

@@ -7,14 +7,87 @@
import logging
from typing import Any, Union
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
@classmethod
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
"""
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
# Select all nodes that belong to the current cycle
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
cycle_node_ids = {node.get("id") for node in cycle_nodes}
cycle_edges = []
remain_edges = []
for edge in edges:
source_in = edge.get("source") in cycle_node_ids
target_in = edge.get("target") in cycle_node_ids
# Raise error if cycle nodes are connected with external nodes
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
cycle_edges.append(edge)
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != node_id
]
workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
if not isinstance(workflow_config, dict):
workflow_config = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
}
cycle_nodes = [
node.get("id")
for node in workflow_config.get("nodes", [])
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
]
graphs = []
for cycle_node in cycle_nodes:
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
graphs.append({
"nodes": nodes,
"edges": edges,
})
graphs.append(workflow_config)
return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
@@ -38,84 +111,79 @@ class WorkflowValidator:
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i}source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for graph in graphs:
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
variables = graph.get("variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i}target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
@@ -129,7 +197,7 @@ class WorkflowValidator:
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
@@ -138,9 +206,9 @@ class WorkflowValidator:
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
@@ -154,39 +222,39 @@ class WorkflowValidator:
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
@@ -196,19 +264,19 @@ class WorkflowValidator:
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
@@ -221,30 +289,30 @@ class WorkflowValidator:
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
@@ -254,13 +322,13 @@ class WorkflowValidator:
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)

View File

@@ -198,19 +198,22 @@ class VariablePool:
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
if namespace == "conv":
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")

View File

@@ -26,6 +26,7 @@ class Document(Base):
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"scene_name": "",
"entity_types": [
"organization",
"person",
@@ -33,7 +34,9 @@ class Document(Base):
"event",
"category"
],
"method": "general"
"method": "general",
"resolution": True,
"community": True
}
}, comment="default parser config")
chunk_num = Column(Integer, default=0, comment="chunk num")

View File

@@ -65,6 +65,7 @@ class Knowledge(Base):
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"scene_name": "",
"entity_types": [
"organization",
"person",
@@ -72,7 +73,9 @@ class Knowledge(Base):
"event",
"category"
],
"method": "general"
"method": "general",
"resolution": True,
"community": True
}
},
comment="default parser config")

View File

@@ -58,13 +58,13 @@ def get_chunked_knowledgeids(
) -> list:
"""
Query the list of vectorized knowledge base IDs
Return: list[UUID] - List of knowledge base IDs
Return: list[(id,workspace_id)] - List of knowledge base id and workspace_id
"""
db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}")
try:
# Only query the id field
query = db.query(Knowledge.id)
query = db.query(Knowledge.id, Knowledge.workspace_id)
# Apply filter conditions
for filter_cond in filters:
@@ -74,8 +74,8 @@ def get_chunked_knowledgeids(
items = query.all()
db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
return [item[0] for item in items]
# Return the list of ID and workspace_id directly. Since only the ID and workspace_id field is queried
return items
except Exception as e:
db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}")
raise

View File

@@ -61,14 +61,14 @@ def get_source_kb_ids_by_target_kb_id(
) -> list:
"""
Query the original knowledge base ID list by sharing the knowledge base
Return: list[UUID] - List of knowledge base IDs
Return: list[(source_kb_id,source_workspace_id)] - List of knowledge base source_kb_id and source_workspace_id
"""
db_logger.debug(
f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}")
try:
# Only query the id field
query = db.query(KnowledgeShare.source_kb_id)
query = db.query(KnowledgeShare.source_kb_id, KnowledgeShare.source_workspace_id)
# Apply filter conditions
for filter_cond in filters:
@@ -78,8 +78,8 @@ def get_source_kb_ids_by_target_kb_id(
items = query.all()
db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}")
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
return [item[0] for item in items]
# Return the list of source_kb_id and source_workspace_id directly. Since only the source_kb_id and source_workspace_id field is queried
return items
except Exception as e:
db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}")
raise

View File

@@ -32,6 +32,7 @@ class KnowledgeRetrievalConfig(BaseModel):
)
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
use_graph: bool = Field(default=False, description="是否使用图搜索")
class ToolConfig(BaseModel):

View File

@@ -10,6 +10,7 @@ class RetrieveType(StrEnum):
PARTICIPLE = "participle"
SEMANTIC = "semantic"
HYBRID = "hybrid"
Graph = "graph"
class ChunkCreate(BaseModel):

View File

@@ -12,8 +12,8 @@ class Memory_Reflection(BaseModel):
config_id: Optional[int] = None
reflection_enabled: bool
reflection_period_in_hours: str
reflexion_range: str
baseline: str
reflexion_range: Optional[str] = "partial"
baseline: Optional[str] = "TIME"
reflection_model_id: str
memory_verify: bool
quality_assessment: bool

View File

@@ -20,6 +20,7 @@ class NodeDefinition(BaseModel):
id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
name: str | None = Field(None, description="节点名称")
cycle: str | None = Field(None, description="父循环节点id")
description: str | None = Field(None, description="节点描述")
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")

View File

@@ -42,7 +42,7 @@ nodes:
- 适当使用格式化(如列表、段落)提高可读性
- role: user
content: "{{ sys.message }}"
content: "{{sys.message}}"
model_id: null
temperature: 0.7
@@ -55,7 +55,7 @@ nodes:
type: end
name: 结束
config:
output: "{{ llm_qa.output }}"
output: "{{llm_qa.output}}"
position:
x: 900
y: 100

View File

@@ -135,6 +135,8 @@ dependencies = [
"graspologic==3.4.5.dev2",
"markdown-to-json==2.1.1",
"valkey==6.0.2",
"python-calamine>=0.4.0",
"xlrd==2.0.2"
]
[tool.pytest.ini_options]

View File

@@ -129,3 +129,5 @@ editdistance==0.8.1
graspologic==3.4.5.dev2
markdown-to-json==2.1.1
valkey==6.0.2
python-calamine>=0.4.0
xlrd==2.0.2