Merge branch 'refs/heads/develop' into feature/20251219_xjn
This commit is contained in:
@@ -18,6 +18,9 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.core.response_utils import success
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
@@ -389,36 +392,41 @@ async def retrieve_chunks(
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
existing_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
private_items = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
private_kb_ids = [item[0] for item in private_items]
|
||||
private_workspace_ids = [item[1] for item in private_items]
|
||||
filters = [
|
||||
knowledge_model.Knowledge.id.in_(retrieve_data.kb_ids),
|
||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||
knowledge_model.Knowledge.chunk_num > 0,
|
||||
knowledge_model.Knowledge.status == 1
|
||||
]
|
||||
share_ids = knowledge_service.get_chunded_knowledgeids(
|
||||
items = knowledge_service.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
if share_ids:
|
||||
if items:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(retrieve_data.kb_ids)
|
||||
]
|
||||
items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
share_items = knowledgeshare_service.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters,
|
||||
current_user=current_user
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
if not existing_ids:
|
||||
share_kb_ids = [item[0] for item in share_items]
|
||||
share_workspace_ids = [item[1] for item in share_items]
|
||||
private_kb_ids.extend(share_kb_ids)
|
||||
private_workspace_ids.extend(share_workspace_ids)
|
||||
if not private_kb_ids:
|
||||
return success(data=[], msg="retrieval successful")
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
kb_id = private_kb_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in private_kb_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -448,4 +456,21 @@ async def retrieve_chunks(
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
@@ -86,5 +86,5 @@
|
||||
- **quality_assessment**:
|
||||
quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true\memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
模式参考:{{ json_schema }}
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Sequence, Type, Union
|
||||
from copy import deepcopy
|
||||
from urllib.parse import urlparse
|
||||
@@ -8,8 +7,10 @@ from langchain_core.callbacks import Callbacks
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_rerank_class, RedBearModelFactory
|
||||
from app.models import ModelProvider
|
||||
|
||||
|
||||
class RedBearRerank(BaseDocumentCompressor):
|
||||
""" Rerank → 作为 Runnable 插入任意 LCEL 链"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
@@ -22,10 +23,10 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
return model_class(**model_params)
|
||||
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""
|
||||
Compress documents using Jina's Rerank API.
|
||||
@@ -46,17 +47,17 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
compressed.append(doc_copy)
|
||||
return compressed
|
||||
|
||||
|
||||
def rerank(
|
||||
self,
|
||||
documents: Sequence[Union[str, Document, dict]],
|
||||
query: str,
|
||||
*,
|
||||
top_n: Optional[int] = -1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
provider = self._config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
self,
|
||||
documents: Sequence[Union[str, Document, dict]],
|
||||
query: str,
|
||||
*,
|
||||
top_n: Optional[int] = -1,
|
||||
) -> List[Dict[str, Any]]:
|
||||
provider = self._config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import langchain_community.document_compressors.jina_rerank as jina_mod
|
||||
|
||||
# 规范化:如果不以 /v1/rerank 结尾,则补齐;若已以 /v1 结尾,则补 /rerank
|
||||
def _normalize_jina_base(base_url: Optional[str]) -> Optional[str]:
|
||||
if not base_url:
|
||||
@@ -73,8 +74,7 @@ class RedBearRerank(BaseDocumentCompressor):
|
||||
# 设置完整的 rerank 端点,例如 http://host:port/v1/rerank
|
||||
jina_mod.JINA_API_URL = jina_base
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
model_instance : JinaRerank = self._model
|
||||
return model_instance.rerank(documents = documents, query = query, top_n=top_n)
|
||||
model_instance: JinaRerank = self._model
|
||||
return model_instance.rerank(documents=documents, query=query, top_n=top_n)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型提供商: {provider}")
|
||||
|
||||
@@ -51,7 +51,7 @@ def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs):
|
||||
img_binary = io.BytesIO()
|
||||
img.save(img_binary, format="JPEG")
|
||||
img_binary.seek(0)
|
||||
ans = vision_model.describe(img_binary.read())
|
||||
ans, ans_num_tokens = vision_model.describe(img_binary.read())
|
||||
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
|
||||
txt += "\n" + ans
|
||||
tokenize(doc, txt, eng)
|
||||
|
||||
@@ -42,11 +42,14 @@ class RAGExcelParser:
|
||||
file_like_object.seek(0)
|
||||
try:
|
||||
dfs = pd.read_excel(file_like_object, sheet_name=None)
|
||||
if isinstance(dfs, dict):
|
||||
dfs = next(iter(dfs.values()))
|
||||
return RAGExcelParser._dataframe_to_workbook(dfs)
|
||||
except Exception as ex:
|
||||
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_excel(file_like_object, engine="calamine")
|
||||
print(df)
|
||||
return RAGExcelParser._dataframe_to_workbook(df)
|
||||
except Exception as e_pandas:
|
||||
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
import json_repair
|
||||
import pandas as pd
|
||||
import time
|
||||
import trio
|
||||
|
||||
from app.core.rag.common.misc_utils import get_uuid
|
||||
@@ -262,21 +263,21 @@ class KGSearch(Dealer):
|
||||
relas = ""
|
||||
|
||||
return {
|
||||
"chunk_id": get_uuid(),
|
||||
"content_ltks": "",
|
||||
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms,
|
||||
comm_topn, max_token),
|
||||
"page_content": ents + relas + self._community_retrieval_([n for n, _ in ents_from_query], filters, kb_ids, idxnms, comm_topn, max_token),
|
||||
"vector": None,
|
||||
"metadata": {
|
||||
"doc_id": get_uuid(),
|
||||
"file_id": "",
|
||||
"file_name": "Related content in Knowledge Graph",
|
||||
"file_created_at": int(time.time() * 1000),
|
||||
"document_id": "",
|
||||
"docnm_kwd": "Related content in Knowledge Graph",
|
||||
"kb_id": kb_ids,
|
||||
"important_kwd": [],
|
||||
"image_id": "",
|
||||
"similarity": 1.,
|
||||
"vector_similarity": 1.,
|
||||
"term_similarity": 0,
|
||||
"vector": [],
|
||||
"positions": [],
|
||||
}
|
||||
"knowledge_id": kb_ids,
|
||||
"sort_id": 0,
|
||||
"status": 1,
|
||||
"score": 1
|
||||
},
|
||||
"children": None
|
||||
}
|
||||
|
||||
def _community_retrieval_(self, entities, condition, kb_ids, idxnms, topn, max_token):
|
||||
## Community retrieval
|
||||
|
||||
@@ -26,6 +26,8 @@ from app.core.rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr
|
||||
from app.core.rag.common.string_utils import remove_redundant_spaces
|
||||
from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
|
||||
|
||||
def knowledge_retrieval(
|
||||
@@ -48,6 +50,7 @@ def knowledge_retrieval(
|
||||
- merge_strategy: "weight" or other strategies
|
||||
- reranker_id: UUID of the reranker to use
|
||||
- reranker_top_k: int
|
||||
- use_graph: bool, whether to use a graph
|
||||
|
||||
Returns:
|
||||
Rearranged document block list (in descending order of relevance)
|
||||
@@ -59,6 +62,7 @@ def knowledge_retrieval(
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
||||
|
||||
file_names_filter = []
|
||||
if user_ids:
|
||||
@@ -67,6 +71,10 @@ def knowledge_retrieval(
|
||||
if not knowledge_bases:
|
||||
return []
|
||||
|
||||
kb_ids = []
|
||||
workspace_ids = []
|
||||
chat_model = None
|
||||
embedding_model = None
|
||||
all_results = []
|
||||
# Search each knowledge base
|
||||
for kb_config in knowledge_bases:
|
||||
@@ -87,6 +95,22 @@ def knowledge_retrieval(
|
||||
else:
|
||||
continue
|
||||
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
@@ -136,6 +160,12 @@ def knowledge_retrieval(
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
# use graph
|
||||
if use_graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -213,7 +213,7 @@ class ESConnection(DocStoreConnection):
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bqry.to_dict(),
|
||||
similarity=similarity,
|
||||
# similarity=similarity
|
||||
)
|
||||
|
||||
if bqry and rank_feature:
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
import logging
|
||||
# import uuid
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -107,7 +107,13 @@ class WorkflowExecutor:
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"streaming_buffer": {} # 流式缓冲区
|
||||
"streaming_buffer": {}, # 流式缓冲区
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": False # loop runing flag, only use in loop node,not use in main loop
|
||||
}
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
@@ -199,6 +205,10 @@ class WorkflowExecutor:
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
||||
continue
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == NodeType.START:
|
||||
@@ -271,7 +281,7 @@ class WorkflowExecutor:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
for edge in self.edges:
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
@@ -284,12 +294,12 @@ class WorkflowExecutor:
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
# # 处理到 end 节点的边
|
||||
# if target in end_node_ids:
|
||||
# # 连接到 end 节点
|
||||
# workflow.add_edge(source, target)
|
||||
# logger.debug(f"添加边: {source} -> {target}")
|
||||
# continue
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
@@ -297,22 +307,30 @@ class WorkflowExecutor:
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
def make_router(cond, tgt):
|
||||
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
||||
|
||||
workflow.add_conditional_edges(source, router)
|
||||
|
||||
def router_fn(state: WorkflowState):
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END
|
||||
|
||||
# 动态修改函数名,避免重复
|
||||
router_fn.__name__ = f"router_{tgt}"
|
||||
return router_fn
|
||||
|
||||
router_fn = make_router(condition, target)
|
||||
workflow.add_conditional_edges(source, router_fn)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
|
||||
@@ -74,6 +74,7 @@ class ExpressionEvaluator:
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,8 +40,8 @@ class AssignerNode(BaseNode):
|
||||
variable_selector = expression.split('.')
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||
raise ValueError("Only conversation variables can be assigned.")
|
||||
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = assignment.value
|
||||
@@ -55,7 +55,9 @@ class AssignerNode(BaseNode):
|
||||
)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
||||
pool.get(variable_selector)
|
||||
)(
|
||||
pool, variable_selector, value
|
||||
)
|
||||
|
||||
@@ -81,3 +83,5 @@ class AssignerNode(BaseNode):
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
|
||||
@@ -14,9 +14,13 @@ class VariableType(StrEnum):
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义
|
||||
|
||||
@@ -20,40 +20,44 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""工作流状态
|
||||
|
||||
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# 消息列表(追加模式)
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
|
||||
# 输入变量(从配置的 variables 传入)
|
||||
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx)
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: bool
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
|
||||
# 格式:{node_id: business_result}
|
||||
|
||||
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
||||
# Format: {node_id: business_result}
|
||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 执行上下文
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# 错误信息(用于错误边)
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# 流式缓冲区(存储节点的实时流式输出)
|
||||
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||
|
||||
# Streaming buffer (stores real-time streaming output of nodes)
|
||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
|
||||
@@ -74,6 +78,7 @@ class BaseNode(ABC):
|
||||
self.workflow_config = workflow_config
|
||||
self.node_id = node_config["id"]
|
||||
self.node_type = node_config["type"]
|
||||
self.cycle = node_config.get("cycle")
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
@@ -170,10 +175,10 @@ class BaseNode(ABC):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 调用节点的业务逻辑
|
||||
business_result = await asyncio.wait_for(
|
||||
self.execute(state),
|
||||
@@ -200,7 +205,8 @@ class BaseNode(ABC):
|
||||
**wrapped_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
},
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
except TimeoutError:
|
||||
@@ -236,10 +242,10 @@ class BaseNode(ABC):
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
|
||||
3
api/app/core/workflow/nodes/breaker/__init__.py
Normal file
3
api/app/core/workflow/nodes/breaker/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.core.workflow.nodes.breaker.node import BreakNode
|
||||
|
||||
__all__ = ["BreakNode"]
|
||||
33
api/app/core/workflow/nodes/breaker/node.py
Normal file
33
api/app/core/workflow/nodes/breaker/node.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BreakNode(BaseNode):
|
||||
"""
|
||||
Workflow node that immediately stops loop execution.
|
||||
|
||||
When executed, this node sets the 'looping' flag in the workflow state
|
||||
to False, signaling the outer loop runtime to terminate further iterations.
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the break node.
|
||||
|
||||
Args:
|
||||
state: Current workflow state, including loop control flags.
|
||||
|
||||
Effects:
|
||||
- Sets 'looping' in the state to False to stop the loop.
|
||||
- Logs the action for debugging purposes.
|
||||
|
||||
Returns:
|
||||
Optional dictionary indicating the loop has been stopped.
|
||||
"""
|
||||
state["looping"] = False
|
||||
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
@@ -41,5 +42,7 @@ __all__ = [
|
||||
"JinjaRenderNodeConfig",
|
||||
"VariableAggregatorNodeConfig",
|
||||
"ParameterExtractorNodeConfig",
|
||||
"LoopNodeConfig",
|
||||
"IterationNodeConfig",
|
||||
"QuestionClassifierNodeConfig"
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/cycle_graph/__init__.py
Normal file
4
api/app/core/workflow/nodes/cycle_graph/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
|
||||
|
||||
__all__ = ['CycleGraphNode', 'LoopNodeConfig', 'IterationNodeConfig']
|
||||
96
api/app/core/workflow/nodes/cycle_graph/config.py
Normal file
96
api/app/core/workflow/nodes/cycle_graph/config.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
|
||||
|
||||
class CycleVariable(BaseNodeConfig):
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the loop variable"
|
||||
)
|
||||
type: VariableType = Field(
|
||||
...,
|
||||
description="Data type of the loop variable"
|
||||
)
|
||||
value: str = Field(
|
||||
...,
|
||||
description="Initial or current value of the loop variable"
|
||||
)
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
|
||||
comparison_operator: ComparisonOperator = Field(
|
||||
...,
|
||||
description="Operator used to compare the left and right operands"
|
||||
)
|
||||
|
||||
left: str = Field(
|
||||
...,
|
||||
description="Left-hand operand of the comparison expression"
|
||||
)
|
||||
|
||||
right: str = Field(
|
||||
...,
|
||||
description="Right-hand operand of the comparison expression"
|
||||
)
|
||||
|
||||
|
||||
class ConditionsConfig(BaseModel):
|
||||
"""Configuration for loop condition evaluation"""
|
||||
|
||||
logical_operator: LogicOperator = Field(
|
||||
default=LogicOperator.AND.value,
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
)
|
||||
|
||||
expressions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="Collection of condition expressions to be evaluated"
|
||||
)
|
||||
|
||||
|
||||
class LoopNodeConfig(BaseNodeConfig):
|
||||
condition: ConditionsConfig = Field(
|
||||
default_factory=list,
|
||||
description="Conditional configuration that controls loop execution"
|
||||
)
|
||||
|
||||
cycle_vars: list[CycleVariable] = Field(
|
||||
default_factory=list,
|
||||
description="List of variables used and updated during the loop"
|
||||
)
|
||||
|
||||
max_loop: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of loop iterations"
|
||||
)
|
||||
|
||||
|
||||
class IterationNodeConfig(BaseNodeConfig):
|
||||
input: str = Field(
|
||||
...,
|
||||
description="Input of the loop iteration"
|
||||
)
|
||||
|
||||
parallel: bool = Field(
|
||||
default=False,
|
||||
description="Whether to execute loop iterations in parallel"
|
||||
)
|
||||
|
||||
parallel_count: int = Field(
|
||||
default=4,
|
||||
description="Number of iterations to run in parallel"
|
||||
)
|
||||
|
||||
flatten: bool = Field(
|
||||
default=False,
|
||||
description="Whether to flatten the output list from iterations"
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
...,
|
||||
description="Output of the loop iteration"
|
||||
)
|
||||
|
||||
|
||||
154
api/app/core/workflow/nodes/cycle_graph/iteration.py
Normal file
154
api/app/core/workflow/nodes/cycle_graph/iteration.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IterationRuntime:
|
||||
"""
|
||||
Runtime executor for loop/iteration nodes in a workflow.
|
||||
|
||||
This class handles executing iterations over a list variable, supporting
|
||||
optional parallel execution, flattening of output, and loop control via
|
||||
the workflow state.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
):
|
||||
"""
|
||||
Initialize the iteration runtime.
|
||||
|
||||
Args:
|
||||
graph: Compiled workflow graph capable of async invocation.
|
||||
node_id: Unique identifier of the loop node.
|
||||
config: Dictionary containing iteration node configuration.
|
||||
state: Current workflow state at the point of iteration.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
self.typed_config = IterationNodeConfig(**config)
|
||||
self.looping = True
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
|
||||
def _init_iteration_state(self, item, idx):
|
||||
"""
|
||||
Initialize a per-iteration copy of the workflow state.
|
||||
|
||||
Args:
|
||||
item: Current element from the input array for this iteration.
|
||||
idx: Index of the element in the input array.
|
||||
|
||||
Returns:
|
||||
A deep copy of the workflow state with iteration-specific variables set.
|
||||
"""
|
||||
loopstate = WorkflowState(
|
||||
**copy.deepcopy(self.state)
|
||||
)
|
||||
loopstate["runtime_vars"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
loopstate["looping"] = True
|
||||
return loopstate
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
"""
|
||||
Execute a single iteration asynchronously.
|
||||
|
||||
Args:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if not result["looping"]:
|
||||
self.looping = False
|
||||
|
||||
def _create_iteration_tasks(self, array_obj, idx):
|
||||
"""
|
||||
Create async tasks for a batch of iterations based on parallel count.
|
||||
|
||||
Args:
|
||||
array_obj: The input array to iterate over.
|
||||
idx: Starting index for this batch of iterations.
|
||||
|
||||
Returns:
|
||||
List of coroutine tasks ready to be executed in parallel.
|
||||
"""
|
||||
tasks = []
|
||||
for i in range(self.typed_config.parallel_count):
|
||||
if idx + i >= len(array_obj):
|
||||
break
|
||||
item = array_obj[idx + i]
|
||||
tasks.append(self.run_task(item, idx + i))
|
||||
return tasks
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Execute the loop over the input array according to configuration.
|
||||
|
||||
Returns:
|
||||
A list of outputs from all iterations, optionally flattened.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the input variable is not a list.
|
||||
"""
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
|
||||
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
|
||||
|
||||
array_obj = VariablePool(self.state).get(input_expression)
|
||||
if not isinstance(array_obj, list):
|
||||
raise RuntimeError("Cannot iterate over a non-list variable")
|
||||
|
||||
idx = 0
|
||||
if self.typed_config.parallel:
|
||||
# Execute iterations in parallel batches
|
||||
while idx < len(array_obj) and self.looping:
|
||||
tasks = self._create_iteration_tasks(array_obj, idx)
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
await asyncio.gather(*tasks)
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if not result["looping"]:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return self.result
|
||||
130
api/app/core/workflow/nodes/cycle_graph/loop.py
Normal file
130
api/app/core/workflow/nodes/cycle_graph/loop.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopRuntime:
|
||||
"""
|
||||
Runtime executor for loop nodes in a workflow.
|
||||
|
||||
Handles iterative execution of a loop node according to defined loop variables
|
||||
and conditional expressions. Supports maximum loop count and loop control
|
||||
through the workflow state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
):
|
||||
"""
|
||||
Initialize the loop runtime.
|
||||
|
||||
Args:
|
||||
graph: Compiled workflow graph capable of async invocation.
|
||||
node_id: Unique identifier of the loop node.
|
||||
config: Dictionary containing loop node configuration.
|
||||
state: Current workflow state at the point of loop execution.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
self.typed_config = LoopNodeConfig(**config)
|
||||
|
||||
def _init_loop_state(self):
|
||||
"""
|
||||
Initialize workflow state for loop execution.
|
||||
|
||||
- Evaluates initial values of loop variables.
|
||||
- Stores loop variables in runtime_vars and node_outputs.
|
||||
- Marks the loop as active by setting 'looping' to True.
|
||||
|
||||
Returns:
|
||||
A copy of the workflow state prepared for the loop execution.
|
||||
"""
|
||||
pool = VariablePool(self.state)
|
||||
# 循环变量
|
||||
self.state["runtime_vars"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.state["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["looping"] = True
|
||||
return loopstate
|
||||
|
||||
def _get_loop_expression(self):
|
||||
"""
|
||||
Build the Python boolean expression for evaluating the loop condition.
|
||||
|
||||
- Converts each condition in the loop configuration into a Python expression string.
|
||||
- Combines multiple conditions with the configured logical operator (AND/OR).
|
||||
|
||||
Returns:
|
||||
A string representing the combined loop condition expression.
|
||||
"""
|
||||
branch_conditions = [
|
||||
ConditionExpressionBuilder(
|
||||
left=condition.left,
|
||||
operator=condition.comparison_operator,
|
||||
right=condition.right
|
||||
).build()
|
||||
for condition in self.typed_config.condition.expressions
|
||||
]
|
||||
if len(branch_conditions) > 1:
|
||||
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
|
||||
else:
|
||||
combined_condition = branch_conditions[0]
|
||||
|
||||
return combined_condition
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Execute the loop node until the condition is no longer met, the loop is
|
||||
manually stopped, or the maximum loop count is reached.
|
||||
|
||||
Returns:
|
||||
The final runtime variables of this loop node after completion.
|
||||
"""
|
||||
loopstate = self._init_loop_state()
|
||||
expression = self._get_loop_expression()
|
||||
loop_variable_pool = VariablePool(loopstate)
|
||||
loop_time = self.typed_config.max_loop
|
||||
while evaluate_condition(
|
||||
expression=expression,
|
||||
variables=loop_variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
||||
) and loopstate["looping"] and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
await self.graph.ainvoke(loopstate)
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id]
|
||||
226
api/app/core/workflow/nodes/cycle_graph/node.py
Normal file
226
api/app/core/workflow/nodes/cycle_graph/node.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
Node representing a cycle (loop) subgraph within the workflow.
|
||||
|
||||
This node manages internal loop/iteration nodes, builds a subgraph
|
||||
for execution, handles conditional routing, and executes loop
|
||||
or iteration logic based on node type.
|
||||
"""
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
self.end_node_ids = [] # IDs of end nodes within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def pure_cycle_graph(self) -> tuple[list, list]:
|
||||
"""
|
||||
Extract cycle nodes and internal edges from the workflow configuration,
|
||||
removing them from the global workflow.
|
||||
|
||||
Raises:
|
||||
ValueError: If cycle nodes are connected to external nodes improperly.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- cycle_nodes: List of removed nodes
|
||||
- cycle_edges: List of removed edges
|
||||
"""
|
||||
nodes = self.workflow_config.get("nodes", [])
|
||||
edges = self.workflow_config.get("edges", [])
|
||||
|
||||
# Select all nodes that belong to the current cycle
|
||||
cycle_nodes = [node for node in nodes if node.get("cycle") == self.node_id]
|
||||
cycle_node_ids = {node.get("id") for node in cycle_nodes}
|
||||
|
||||
cycle_edges = []
|
||||
remain_edges = []
|
||||
|
||||
for edge in edges:
|
||||
source_in = edge.get("source") in cycle_node_ids
|
||||
target_in = edge.get("target") in cycle_node_ids
|
||||
|
||||
# Raise error if cycle nodes are connected with external nodes
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
cycle_edges.append(edge)
|
||||
else:
|
||||
remain_edges.append(edge)
|
||||
|
||||
# Update workflow_config by removing cycle nodes and internal edges
|
||||
self.workflow_config["nodes"] = [
|
||||
node for node in nodes if node.get("cycle") != self.node_id
|
||||
]
|
||||
self.workflow_config["edges"] = remain_edges
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
def create_node(self):
|
||||
"""
|
||||
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
|
||||
|
||||
Special handling is applied for conditional nodes to generate
|
||||
edge conditions based on node outputs.
|
||||
"""
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
for node in self.cycle_nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
if node_type == NodeType.CYCLE_START:
|
||||
self.start_node_id = node_id
|
||||
continue
|
||||
elif node_type == NodeType.END:
|
||||
self.end_node_ids.append(node_id)
|
||||
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
|
||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
||||
expressions = node_instance.build_conditional_edge_expressions()
|
||||
|
||||
# Number of branches, usually matches the number of conditional expressions
|
||||
branch_number = len(expressions)
|
||||
|
||||
# Find all edges whose source is the current node
|
||||
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
|
||||
|
||||
# Iterate over each branch
|
||||
for idx in range(branch_number):
|
||||
# Generate a condition expression for each edge
|
||||
# Used later to determine which branch to take based on the node's output
|
||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
def make_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
|
||||
return node_func
|
||||
|
||||
self.graph.add_node(node_id, make_func(node_instance))
|
||||
|
||||
def create_edge(self):
|
||||
"""
|
||||
Connect nodes within the cycle subgraph by adding edges to the internal graph.
|
||||
|
||||
Conditional edges are routed based on evaluated expressions.
|
||||
Start and end nodes are connected to global START and END nodes.
|
||||
"""
|
||||
for edge in self.cycle_edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == self.start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
self.graph.add_edge(START, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
|
||||
self.graph.add_conditional_edges(source, router)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
self.graph.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
for end_node_id in self.end_node_ids:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
def build_graph(self):
|
||||
"""
|
||||
Build the internal subgraph for the cycle node.
|
||||
|
||||
Steps:
|
||||
1. Extract cycle nodes and edges.
|
||||
2. Create node instances and add them to the graph.
|
||||
3. Connect edges and conditional routes.
|
||||
4. Compile the graph for execution.
|
||||
"""
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.create_node()
|
||||
self.create_edge()
|
||||
self.graph = self.graph.compile()
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the cycle node at runtime.
|
||||
|
||||
Depending on the node type, runs either a loop (LoopRuntime)
|
||||
or an iteration (IterationRuntime) over the internal subgraph.
|
||||
|
||||
Args:
|
||||
state: Current workflow state.
|
||||
|
||||
Returns:
|
||||
Runtime result of the cycle, typically the final loop/iteration variables.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If node type is unrecognized.
|
||||
"""
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
@@ -1,14 +1,5 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.operators import (
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
AssignmentOperatorType,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
)
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -27,6 +18,10 @@ class NodeType(StrEnum):
|
||||
JINJARENDER = "jinja-render"
|
||||
VAR_AGGREGATOR = "var-aggregator"
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
LOOP = "loop"
|
||||
ITERATION = "iteration"
|
||||
CYCLE_START = "cycle-start"
|
||||
BREAK = "break"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
@@ -62,21 +57,6 @@ class AssignmentOperator(StrEnum):
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
|
||||
@classmethod
|
||||
def get_operator(cls, obj) -> AssignmentOperatorType:
|
||||
if isinstance(obj, str):
|
||||
return StringOperator
|
||||
elif isinstance(obj, bool):
|
||||
return BooleanOperator
|
||||
elif isinstance(obj, (int, float)):
|
||||
return NumberOperator
|
||||
elif isinstance(obj, list):
|
||||
return ArrayOperator
|
||||
elif isinstance(obj, dict):
|
||||
return ObjectOperator
|
||||
|
||||
raise TypeError(f"Unsupported variable type ({type(obj)})")
|
||||
|
||||
|
||||
class HttpRequestMethod(StrEnum):
|
||||
GET = "GET"
|
||||
|
||||
@@ -215,6 +215,7 @@ class HttpRequestNode(BaseNode):
|
||||
**self._build_content(state)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
status_code=resp.status_code,
|
||||
@@ -228,12 +229,21 @@ class HttpRequestNode(BaseNode):
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
)
|
||||
return self.typed_config.error_handle.default.model_dump()
|
||||
case HttpErrorHandle.BRANCH:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConditionBranchConfig(BaseModel):
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
)
|
||||
|
||||
conditions: list[ConditionDetail] = Field(
|
||||
expressions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="List of condition expressions within this branch"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
# CASE1 / IF Branch
|
||||
{
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
"expressions": [
|
||||
[
|
||||
{
|
||||
"left": "node.userinput.message",
|
||||
@@ -75,7 +75,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
# CASE1 / ELIF Branch
|
||||
{
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
"expressions": [
|
||||
[
|
||||
{
|
||||
"left": "node.userinput.test",
|
||||
|
||||
@@ -2,93 +2,13 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
|
||||
"""
|
||||
Build a Python boolean expression string based on a comparison operator.
|
||||
|
||||
This class does not evaluate the expression.
|
||||
It only generates a valid Python expression string
|
||||
that can be evaluated later in a workflow context.
|
||||
"""
|
||||
|
||||
def __init__(self, left: str, operator: ComparisonOperator, right: str):
|
||||
self.left = left
|
||||
self.operator = operator
|
||||
self.right = right
|
||||
|
||||
def _empty(self):
|
||||
return f"{self.left} == ''"
|
||||
|
||||
def _not_empty(self):
|
||||
return f"{self.left} != ''"
|
||||
|
||||
def _contains(self):
|
||||
return f"{self.right} in {self.left}"
|
||||
|
||||
def _not_contains(self):
|
||||
return f"{self.right} not in {self.left}"
|
||||
|
||||
def _startwith(self):
|
||||
return f'{self.left}.startswith({self.right})'
|
||||
|
||||
def _endwith(self):
|
||||
return f'{self.left}.endswith({self.right})'
|
||||
|
||||
def _eq(self):
|
||||
return f"{self.left} == {self.right}"
|
||||
|
||||
def _ne(self):
|
||||
return f"{self.left} != {self.right}"
|
||||
|
||||
def _lt(self):
|
||||
return f"{self.left} < {self.right}"
|
||||
|
||||
def _le(self):
|
||||
return f"{self.left} <= {self.right}"
|
||||
|
||||
def _gt(self):
|
||||
return f"{self.left} > {self.right}"
|
||||
|
||||
def _ge(self):
|
||||
return f"{self.left} >= {self.right}"
|
||||
|
||||
def build(self):
|
||||
match self.operator:
|
||||
case ComparisonOperator.EMPTY:
|
||||
return self._empty()
|
||||
case ComparisonOperator.NOT_EMPTY:
|
||||
return self._not_empty()
|
||||
case ComparisonOperator.CONTAINS:
|
||||
return self._contains()
|
||||
case ComparisonOperator.NOT_CONTAINS:
|
||||
return self._not_contains()
|
||||
case ComparisonOperator.START_WITH:
|
||||
return self._startwith()
|
||||
case ComparisonOperator.END_WITH:
|
||||
return self._endwith()
|
||||
case ComparisonOperator.EQ:
|
||||
return self._eq()
|
||||
case ComparisonOperator.NE:
|
||||
return self._ne()
|
||||
case ComparisonOperator.LT:
|
||||
return self._lt()
|
||||
case ComparisonOperator.LE:
|
||||
return self._le()
|
||||
case ComparisonOperator.GT:
|
||||
return self._gt()
|
||||
case ComparisonOperator.GE:
|
||||
return self._ge()
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
@@ -143,7 +63,7 @@ class IfElseNode(BaseNode):
|
||||
|
||||
branch_conditions = [
|
||||
self._build_condition_expression(condition)
|
||||
for condition in case_branch.conditions
|
||||
for condition in case_branch.expressions
|
||||
]
|
||||
if len(branch_conditions) > 1:
|
||||
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
||||
@@ -174,5 +94,6 @@ class IfElseNode(BaseNode):
|
||||
for i in range(len(expressions)):
|
||||
logger.info(expressions[i])
|
||||
if self._evaluate_condition(expressions[i], state):
|
||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||
return f'CASE{i + 1}'
|
||||
return f'CASE{len(expressions)}'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
@@ -5,6 +6,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -41,5 +43,5 @@ class JinjaRenderNode(BaseNode):
|
||||
res = render.env.from_string(self.typed_config.template).render(**context)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"JinjaRender Node {self.node_name} render failed: {e}") from e
|
||||
|
||||
logger.info(f"Node {self.node_id}: Jinja template rendering completed")
|
||||
return res
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query string"
|
||||
)
|
||||
|
||||
kb_ids: list[UUID] = Field(
|
||||
class KnowledgeBaseConfig(BaseModel):
|
||||
kb_id: UUID = Field(
|
||||
...,
|
||||
description="Knowledge base IDs"
|
||||
)
|
||||
@@ -37,18 +32,42 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
description="Retrieve type"
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query string"
|
||||
)
|
||||
|
||||
knowledge_bases: list[KnowledgeBaseConfig] = Field(
|
||||
...,
|
||||
description="Knowledge base config"
|
||||
)
|
||||
|
||||
reranker_id: UUID = Field(
|
||||
...,
|
||||
description="Reranker top k"
|
||||
)
|
||||
|
||||
reranker_top_k: int = Field(
|
||||
default=4,
|
||||
description="Knowledge base top k"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"query": "{{sys.message}}",
|
||||
"kb_ids": [
|
||||
"xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
|
||||
],
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.3,
|
||||
"top_k": 1,
|
||||
"retrieve_type": "hybrid"
|
||||
"knowledge_bases": [{
|
||||
"kb_id": "xxxxxxxx-xxxx-xxxx-xxxxxxxxxxxxxxxxx",
|
||||
"similarity_threshold": 0.2,
|
||||
"vector_similarity_weight": 0.3,
|
||||
"top_k": 4,
|
||||
"retrieve_type": "hybrid"
|
||||
}],
|
||||
"reranker_top_k": 1,
|
||||
"reranker_id": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -2,14 +2,18 @@ import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services import knowledge_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -108,6 +112,44 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
existing_ids.extend(items)
|
||||
return existing_ids
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
|
||||
"""
|
||||
Retrieve and initialize a RedBear reranker model based on configuration.
|
||||
|
||||
Raises:
|
||||
BusinessException: If configuration is missing or API keys are not set.
|
||||
RuntimeError: If the configured model is not of type RERANK.
|
||||
"""
|
||||
with get_db_read() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.reranker_id)
|
||||
|
||||
if not config:
|
||||
raise BusinessException("Configured model does not exist", BizCode.NOT_FOUND)
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
if model_type != ModelType.RERANK:
|
||||
raise RuntimeError("Model is not a reranker")
|
||||
|
||||
reranker = RedBearRerank(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
)
|
||||
)
|
||||
return reranker
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
@@ -131,38 +173,45 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"""
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
with get_db_read() as db:
|
||||
existing_ids = self._get_existing_kb_ids(db, self.typed_config.kb_ids)
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
if not existing_ids:
|
||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||
|
||||
kb_id = existing_ids[0]
|
||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||
indices = ",".join(uuid_strs)
|
||||
rs = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
match self.typed_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=self.typed_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||
return [chunk.model_dump() for chunk in rs]
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Union
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNode
|
||||
@@ -22,6 +23,7 @@ from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,6 +41,9 @@ WorkflowNode = Union[
|
||||
JinjaRenderNode,
|
||||
VariableAggregatorNode,
|
||||
ParameterExtractorNode,
|
||||
CycleGraphNode,
|
||||
BreakNode,
|
||||
ParameterExtractorNode,
|
||||
QuestionClassifierNode
|
||||
]
|
||||
|
||||
@@ -64,6 +69,9 @@ class NodeFactory:
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.LOOP: CycleGraphNode,
|
||||
NodeType.ITERATION: CycleGraphNode,
|
||||
NodeType.BREAK: BreakNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type
|
||||
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
@@ -136,6 +137,23 @@ class ObjectOperator(OperatorBase):
|
||||
self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
class AssignmentOperatorResolver:
|
||||
@classmethod
|
||||
def resolve_by_value(cls, value):
|
||||
if isinstance(value, str):
|
||||
return StringOperator
|
||||
elif isinstance(value, bool):
|
||||
return BooleanOperator
|
||||
elif isinstance(value, (int, float)):
|
||||
return NumberOperator
|
||||
elif isinstance(value, list):
|
||||
return ArrayOperator
|
||||
elif isinstance(value, dict):
|
||||
return ObjectOperator
|
||||
else:
|
||||
raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||
|
||||
|
||||
AssignmentOperatorInstance = Union[
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
@@ -144,3 +162,83 @@ AssignmentOperatorInstance = Union[
|
||||
ObjectOperator
|
||||
]
|
||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
|
||||
"""
|
||||
Build a Python boolean expression string based on a comparison operator.
|
||||
|
||||
This class does not evaluate the expression.
|
||||
It only generates a valid Python expression string
|
||||
that can be evaluated later in a workflow context.
|
||||
"""
|
||||
|
||||
def __init__(self, left: str, operator: ComparisonOperator, right: str):
|
||||
self.left = left
|
||||
self.operator = operator
|
||||
self.right = right
|
||||
|
||||
def _empty(self):
|
||||
return f"{self.left} == ''"
|
||||
|
||||
def _not_empty(self):
|
||||
return f"{self.left} != ''"
|
||||
|
||||
def _contains(self):
|
||||
return f"{self.right} in {self.left}"
|
||||
|
||||
def _not_contains(self):
|
||||
return f"{self.right} not in {self.left}"
|
||||
|
||||
def _startswith(self):
|
||||
return f'{self.left}.startswith({self.right})'
|
||||
|
||||
def _endswith(self):
|
||||
return f'{self.left}.endswith({self.right})'
|
||||
|
||||
def _eq(self):
|
||||
return f"{self.left} == {self.right}"
|
||||
|
||||
def _ne(self):
|
||||
return f"{self.left} != {self.right}"
|
||||
|
||||
def _lt(self):
|
||||
return f"{self.left} < {self.right}"
|
||||
|
||||
def _le(self):
|
||||
return f"{self.left} <= {self.right}"
|
||||
|
||||
def _gt(self):
|
||||
return f"{self.left} > {self.right}"
|
||||
|
||||
def _ge(self):
|
||||
return f"{self.left} >= {self.right}"
|
||||
|
||||
def build(self):
|
||||
match self.operator:
|
||||
case ComparisonOperator.EMPTY:
|
||||
return self._empty()
|
||||
case ComparisonOperator.NOT_EMPTY:
|
||||
return self._not_empty()
|
||||
case ComparisonOperator.CONTAINS:
|
||||
return self._contains()
|
||||
case ComparisonOperator.NOT_CONTAINS:
|
||||
return self._not_contains()
|
||||
case ComparisonOperator.START_WITH:
|
||||
return self._startswith()
|
||||
case ComparisonOperator.END_WITH:
|
||||
return self._endswith()
|
||||
case ComparisonOperator.EQ:
|
||||
return self._eq()
|
||||
case ComparisonOperator.NE:
|
||||
return self._ne()
|
||||
case ComparisonOperator.LT:
|
||||
return self._lt()
|
||||
case ComparisonOperator.LE:
|
||||
return self._le()
|
||||
case ComparisonOperator.GT:
|
||||
return self._gt()
|
||||
case ComparisonOperator.GE:
|
||||
return self._ge()
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
@@ -36,6 +36,11 @@ class ParamsConfig(BaseModel):
|
||||
description="Description of the parameter"
|
||||
)
|
||||
|
||||
required: bool = Field(
|
||||
...,
|
||||
description="Whether the parameter is required"
|
||||
)
|
||||
|
||||
|
||||
class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||
model_id: uuid.UUID = Field(
|
||||
@@ -52,3 +57,8 @@ class ParameterExtractorNodeConfig(BaseNodeConfig):
|
||||
...,
|
||||
description="List of parameters"
|
||||
)
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
description="User-provided supplemental prompt"
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import json_repair
|
||||
from typing import Any
|
||||
@@ -15,6 +16,8 @@ from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -114,7 +117,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
"""
|
||||
field_type = {}
|
||||
for param in self.typed_config.params:
|
||||
field_type[param.name] = param.type
|
||||
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
|
||||
return field_type
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
@@ -154,12 +157,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
messages = [
|
||||
("system", system_prompt),
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", rendered_user_prompt),
|
||||
]
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
result = json_repair.repair_json(model_resp.content)
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
return {
|
||||
"output": result,
|
||||
}
|
||||
return result
|
||||
|
||||
@@ -9,43 +9,27 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
description="输出变量是否需要分组",
|
||||
)
|
||||
|
||||
group_names: list[str] = Field(
|
||||
default_factory=lambda: ["output"],
|
||||
description="各个分组的名称"
|
||||
)
|
||||
|
||||
group_variables: list[str] | list[list[str]] = Field(
|
||||
group_variables: list[str] | dict[str, list[str]] = Field(
|
||||
...,
|
||||
description="需要被聚合的变量"
|
||||
)
|
||||
|
||||
@field_validator("group_names", mode="before")
|
||||
@classmethod
|
||||
def group_names_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
if not group_status or not v:
|
||||
return ["output"]
|
||||
return v
|
||||
|
||||
@field_validator("group_variables")
|
||||
@classmethod
|
||||
def group_variables_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
group_names = info.data.get("group_names")
|
||||
if not isinstance(v, list):
|
||||
raise ValueError("group_variables must be a list")
|
||||
|
||||
if not group_status:
|
||||
for variable in v:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("When group=False, group_variables must be a list of strings")
|
||||
else:
|
||||
if len(group_names) != len(v):
|
||||
raise ValueError("group_names and group_variables length mismatch")
|
||||
for group in v:
|
||||
if not isinstance(group, list):
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("When group=True, group_variables must be a dict")
|
||||
for group_name, group_values in v.items():
|
||||
if not isinstance(group_name, str):
|
||||
raise ValueError("When group=True, each element of group_variables must be a list")
|
||||
for variable in group:
|
||||
for variable in group_values:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("Each element inside group_variables lists must be a string")
|
||||
return v
|
||||
|
||||
@@ -50,6 +50,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
continue
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {value}")
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
@@ -59,7 +60,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
# Group mode
|
||||
# --------------------------
|
||||
result = {}
|
||||
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
|
||||
for group_name, variables in self.typed_config.group_variables.items():
|
||||
for variable in variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
@@ -74,5 +75,5 @@ class VariableAggregatorNode(BaseNode):
|
||||
else:
|
||||
result[group_name] = ""
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||
return result
|
||||
|
||||
@@ -7,14 +7,87 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""工作流配置验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
|
||||
@classmethod
|
||||
def pure_cycle_graph(cls, workflow_config: Union[dict[str, Any], Any], node_id) -> tuple[list, list]:
|
||||
"""
|
||||
Extract cycle nodes and internal edges from the workflow configuration,
|
||||
removing them from the global workflow.
|
||||
|
||||
Raises:
|
||||
ValueError: If cycle nodes are connected to external nodes improperly.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- cycle_nodes: List of removed nodes
|
||||
- cycle_edges: List of removed edges
|
||||
"""
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
|
||||
# Select all nodes that belong to the current cycle
|
||||
cycle_nodes = [node for node in nodes if node.get("cycle") == node_id]
|
||||
cycle_node_ids = {node.get("id") for node in cycle_nodes}
|
||||
|
||||
cycle_edges = []
|
||||
remain_edges = []
|
||||
|
||||
for edge in edges:
|
||||
source_in = edge.get("source") in cycle_node_ids
|
||||
target_in = edge.get("target") in cycle_node_ids
|
||||
|
||||
# Raise error if cycle nodes are connected with external nodes
|
||||
if source_in ^ target_in:
|
||||
raise ValueError(
|
||||
f"Cycle node is connected to external node, "
|
||||
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||
)
|
||||
|
||||
if source_in and target_in:
|
||||
cycle_edges.append(edge)
|
||||
else:
|
||||
remain_edges.append(edge)
|
||||
|
||||
# Update workflow_config by removing cycle nodes and internal edges
|
||||
workflow_config["nodes"] = [
|
||||
node for node in nodes if node.get("cycle") != node_id
|
||||
]
|
||||
workflow_config["edges"] = remain_edges
|
||||
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@classmethod
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
|
||||
if not isinstance(workflow_config, dict):
|
||||
workflow_config = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
}
|
||||
cycle_nodes = [
|
||||
node.get("id")
|
||||
for node in workflow_config.get("nodes", [])
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
]
|
||||
graphs = []
|
||||
for cycle_node in cycle_nodes:
|
||||
nodes, edges = cls.pure_cycle_graph(workflow_config, cycle_node)
|
||||
graphs.append({
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
})
|
||||
graphs.append(workflow_config)
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
@@ -38,84 +111,79 @@ class WorkflowValidator:
|
||||
True
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 支持字典和 Pydantic 模型
|
||||
if isinstance(workflow_config, dict):
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
variables = workflow_config.get("variables", [])
|
||||
else:
|
||||
# Pydantic 模型
|
||||
nodes = getattr(workflow_config, "nodes", [])
|
||||
edges = getattr(workflow_config, "edges", [])
|
||||
variables = getattr(workflow_config, "variables", [])
|
||||
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
# 4. 验证节点必须有 id 和 type
|
||||
for i, node in enumerate(nodes):
|
||||
if not node.get("id"):
|
||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||
if not node.get("type"):
|
||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||
|
||||
# 5. 验证边的有效性
|
||||
node_id_set = set(node_ids)
|
||||
for i, edge in enumerate(edges):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
if not source:
|
||||
errors.append(f"边 #{i} 缺少 source 字段")
|
||||
elif source not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||
|
||||
if not target:
|
||||
errors.append(f"边 #{i} 缺少 target 字段")
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") in [NodeType.START, NodeType.CYCLE_START]]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
# 4. 验证节点必须有 id 和 type
|
||||
for i, node in enumerate(nodes):
|
||||
if not node.get("id"):
|
||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||
if not node.get("type"):
|
||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||
|
||||
# 5. 验证边的有效性
|
||||
node_id_set = set(node_ids)
|
||||
for i, edge in enumerate(edges):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
if not source:
|
||||
errors.append(f"边 #{i} 缺少 source 字段")
|
||||
elif source not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||
|
||||
if not target:
|
||||
errors.append(f"边 #{i} 缺少 target 字段")
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
"""获取从 start 节点可达的所有节点
|
||||
@@ -129,7 +197,7 @@ class WorkflowValidator:
|
||||
"""
|
||||
reachable = {start_id}
|
||||
queue = [start_id]
|
||||
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for edge in edges:
|
||||
@@ -138,9 +206,9 @@ class WorkflowValidator:
|
||||
if target and target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
|
||||
|
||||
return reachable
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
||||
"""检测是否存在循环依赖(DFS)
|
||||
@@ -154,39 +222,39 @@ class WorkflowValidator:
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
|
||||
# 跳过错误边
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
graph[source].append(target)
|
||||
|
||||
|
||||
# DFS 检测环
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
path = []
|
||||
cycle_path = []
|
||||
|
||||
|
||||
def dfs(node: str) -> bool:
|
||||
"""DFS 检测环,返回是否找到环"""
|
||||
visited.add(node)
|
||||
rec_stack.add(node)
|
||||
path.append(node)
|
||||
|
||||
|
||||
for neighbor in graph.get(node, []):
|
||||
if neighbor not in visited:
|
||||
if dfs(neighbor):
|
||||
@@ -196,19 +264,19 @@ class WorkflowValidator:
|
||||
cycle_start = path.index(neighbor)
|
||||
cycle_path.extend([*path[cycle_start:], neighbor])
|
||||
return True
|
||||
|
||||
|
||||
rec_stack.remove(node)
|
||||
path.pop()
|
||||
return False
|
||||
|
||||
|
||||
# 检查所有节点
|
||||
for node_id in graph:
|
||||
if node_id not in visited:
|
||||
if dfs(node_id):
|
||||
return True, cycle_path
|
||||
|
||||
|
||||
return False, []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置是否可以发布(更严格的验证)
|
||||
@@ -221,30 +289,30 @@ class WorkflowValidator:
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
|
||||
# 额外的发布验证
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
|
||||
|
||||
# 1. 验证所有节点都有名称
|
||||
for node in nodes:
|
||||
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
||||
if node.get("type") not in [NodeType.START, NodeType.CYCLE_START, NodeType.END] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
for node in nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type not in ["start", "end"]:
|
||||
if node_type not in [NodeType.START, NodeType.CYCLE_START, NodeType.END, NodeType.BREAK]:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
|
||||
# 3. 验证必填变量
|
||||
variables = workflow_config.get("variables", [])
|
||||
required_vars = [v for v in variables if v.get("required")]
|
||||
@@ -254,13 +322,13 @@ class WorkflowValidator:
|
||||
f"工作流包含 {len(required_vars)} 个必填变量: "
|
||||
f"{[v.get('name') for v in required_vars]}"
|
||||
)
|
||||
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def validate_workflow_config(
|
||||
workflow_config: dict[str, Any],
|
||||
for_publish: bool = False
|
||||
workflow_config: dict[str, Any],
|
||||
for_publish: bool = False
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置(便捷函数)
|
||||
|
||||
|
||||
@@ -198,19 +198,22 @@ class VariablePool:
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv":
|
||||
raise ValueError("只能设置会话变量 (conv.*)")
|
||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
if namespace == "conv":
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
elif namespace in self.state["cycle_nodes"]:
|
||||
self.state["runtime_vars"][namespace][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class Document(Base):
|
||||
"html4excel": False,
|
||||
"graphrag": {
|
||||
"use_graphrag": False,
|
||||
"scene_name": "",
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
@@ -33,7 +34,9 @@ class Document(Base):
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "general"
|
||||
"method": "general",
|
||||
"resolution": True,
|
||||
"community": True
|
||||
}
|
||||
}, comment="default parser config")
|
||||
chunk_num = Column(Integer, default=0, comment="chunk num")
|
||||
|
||||
@@ -65,6 +65,7 @@ class Knowledge(Base):
|
||||
"html4excel": False,
|
||||
"graphrag": {
|
||||
"use_graphrag": False,
|
||||
"scene_name": "",
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
@@ -72,7 +73,9 @@ class Knowledge(Base):
|
||||
"event",
|
||||
"category"
|
||||
],
|
||||
"method": "general"
|
||||
"method": "general",
|
||||
"resolution": True,
|
||||
"community": True
|
||||
}
|
||||
},
|
||||
comment="default parser config")
|
||||
|
||||
@@ -58,13 +58,13 @@ def get_chunked_knowledgeids(
|
||||
) -> list:
|
||||
"""
|
||||
Query the list of vectorized knowledge base IDs
|
||||
Return: list[UUID] - List of knowledge base IDs
|
||||
Return: list[(id,workspace_id)] - List of knowledge base id and workspace_id
|
||||
"""
|
||||
db_logger.debug(f"Query the list of vectorized knowledge base IDs: filters_count={len(filters)}")
|
||||
|
||||
try:
|
||||
# Only query the id field
|
||||
query = db.query(Knowledge.id)
|
||||
query = db.query(Knowledge.id, Knowledge.workspace_id)
|
||||
|
||||
# Apply filter conditions
|
||||
for filter_cond in filters:
|
||||
@@ -74,8 +74,8 @@ def get_chunked_knowledgeids(
|
||||
items = query.all()
|
||||
db_logger.info(f"Querying the vectorized knowledge base id list succeeded: count={len(items)}")
|
||||
|
||||
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
|
||||
return [item[0] for item in items]
|
||||
# Return the list of ID and workspace_id directly. Since only the ID and workspace_id field is queried
|
||||
return items
|
||||
except Exception as e:
|
||||
db_logger.error(f"Querying the vectorized knowledge base id list failed: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -61,14 +61,14 @@ def get_source_kb_ids_by_target_kb_id(
|
||||
) -> list:
|
||||
"""
|
||||
Query the original knowledge base ID list by sharing the knowledge base
|
||||
Return: list[UUID] - List of knowledge base IDs
|
||||
Return: list[(source_kb_id,source_workspace_id)] - List of knowledge base source_kb_id and source_workspace_id
|
||||
"""
|
||||
db_logger.debug(
|
||||
f"Query the original knowledge base id list by sharing the knowledge base: filters_count={len(filters)}")
|
||||
|
||||
try:
|
||||
# Only query the id field
|
||||
query = db.query(KnowledgeShare.source_kb_id)
|
||||
query = db.query(KnowledgeShare.source_kb_id, KnowledgeShare.source_workspace_id)
|
||||
|
||||
# Apply filter conditions
|
||||
for filter_cond in filters:
|
||||
@@ -78,8 +78,8 @@ def get_source_kb_ids_by_target_kb_id(
|
||||
items = query.all()
|
||||
db_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: count={len(items)}")
|
||||
|
||||
# Return the list of IDs directly. Since only the ID field is queried, the returned data is a single column
|
||||
return [item[0] for item in items]
|
||||
# Return the list of source_kb_id and source_workspace_id directly. Since only the source_kb_id and source_workspace_id field is queried
|
||||
return items
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -32,6 +32,7 @@ class KnowledgeRetrievalConfig(BaseModel):
|
||||
)
|
||||
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
|
||||
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
|
||||
use_graph: bool = Field(default=False, description="是否使用图搜索")
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
|
||||
@@ -10,6 +10,7 @@ class RetrieveType(StrEnum):
|
||||
PARTICIPLE = "participle"
|
||||
SEMANTIC = "semantic"
|
||||
HYBRID = "hybrid"
|
||||
Graph = "graph"
|
||||
|
||||
|
||||
class ChunkCreate(BaseModel):
|
||||
|
||||
@@ -12,8 +12,8 @@ class Memory_Reflection(BaseModel):
|
||||
config_id: Optional[int] = None
|
||||
reflection_enabled: bool
|
||||
reflection_period_in_hours: str
|
||||
reflexion_range: str
|
||||
baseline: str
|
||||
reflexion_range: Optional[str] = "partial"
|
||||
baseline: Optional[str] = "TIME"
|
||||
reflection_model_id: str
|
||||
memory_verify: bool
|
||||
quality_assessment: bool
|
||||
|
||||
@@ -20,6 +20,7 @@ class NodeDefinition(BaseModel):
|
||||
id: str = Field(..., description="节点唯一标识")
|
||||
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
||||
name: str | None = Field(None, description="节点名称")
|
||||
cycle: str | None = Field(None, description="父循环节点id")
|
||||
description: str | None = Field(None, description="节点描述")
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
|
||||
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
|
||||
|
||||
@@ -42,7 +42,7 @@ nodes:
|
||||
- 适当使用格式化(如列表、段落)提高可读性
|
||||
|
||||
- role: user
|
||||
content: "{{ sys.message }}"
|
||||
content: "{{sys.message}}"
|
||||
|
||||
model_id: null
|
||||
temperature: 0.7
|
||||
@@ -55,7 +55,7 @@ nodes:
|
||||
type: end
|
||||
name: 结束
|
||||
config:
|
||||
output: "{{ llm_qa.output }}"
|
||||
output: "{{llm_qa.output}}"
|
||||
position:
|
||||
x: 900
|
||||
y: 100
|
||||
|
||||
@@ -135,6 +135,8 @@ dependencies = [
|
||||
"graspologic==3.4.5.dev2",
|
||||
"markdown-to-json==2.1.1",
|
||||
"valkey==6.0.2",
|
||||
"python-calamine>=0.4.0",
|
||||
"xlrd==2.0.2"
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -129,3 +129,5 @@ editdistance==0.8.1
|
||||
graspologic==3.4.5.dev2
|
||||
markdown-to-json==2.1.1
|
||||
valkey==6.0.2
|
||||
python-calamine>=0.4.0
|
||||
xlrd==2.0.2
|
||||
|
||||
Reference in New Issue
Block a user