Merge branch 'refs/heads/develop' into fix/memory_bug_fix

lixinyue committed 2026-01-27 20:23:57 +08:00
30 changed files with 706 additions and 896 deletions

View File

@@ -310,7 +310,7 @@ async def get_file_url(
try:
if permanent:
# Generate permanent URL (no expiration check)
server_url = f"http://{settings.SERVER_IP}:8000/api"
server_url = settings.FILE_LOCAL_SERVER_URL
url = f"{server_url}/storage/permanent/{file_id}"
return success(
data={

View File

@@ -9,6 +9,25 @@ load_dotenv()
class Settings:
# ========================================================================
# Deployment Mode Configuration
# ========================================================================
# community: community edition (open source, limited features)
# cloud: SaaS cloud edition (full features, usage-based billing)
# enterprise: self-hosted enterprise edition (license-controlled)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License configuration (enterprise edition)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# Billing service configuration (SaaS edition)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# Base URL (used for SSO callbacks, etc.)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -72,6 +91,10 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO auto-login configuration
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
@@ -107,6 +130,7 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
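For orientation, the new deployment-mode and SSO settings above could be consumed roughly as follows. This is a minimal sketch, not part of this commit; the import path and the assumption that SSO_TRUSTED_SOURCES_CONFIG holds a JSON object are illustrative.

import json

from app.core.config import settings  # assumed location of the Settings instance

VALID_MODES = {"community", "cloud", "enterprise"}

def load_deployment_context() -> dict:
    """Validate DEPLOYMENT_MODE and parse the SSO trusted-source map (sketch)."""
    mode = settings.DEPLOYMENT_MODE
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown DEPLOYMENT_MODE: {mode!r}")
    try:
        # SSO_TRUSTED_SOURCES_CONFIG defaults to "{}", i.e. no trusted sources
        trusted_sources = json.loads(settings.SSO_TRUSTED_SOURCES_CONFIG)
    except json.JSONDecodeError:
        trusted_sources = {}
    return {
        "mode": mode,
        "license_file": settings.LICENSE_FILE if mode == "enterprise" else None,
        "billing_url": settings.BILLING_SERVICE_URL if mode == "cloud" else None,
        "sso_trusted_sources": trusted_sources,
        "sso_token_ttl": settings.SSO_TOKEN_EXPIRE_SECONDS,
    }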

View File

@@ -1,165 +0,0 @@
import copy
import re
from io import BytesIO
from PIL import Image
from app.core.rag.nlp import tokenize, is_english
from app.core.rag.nlp import rag_tokenizer
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
from app.core.rag.deepdoc.parser.ppt_parser import RAGPptParser as PptParser
from PyPDF2 import PdfReader as pdf2_read
from app.core.rag.app.naive import by_plaintext, PARSERS
class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
txts = super().__call__(fnm, from_page, to_page)
callback(0.5, "Text extraction finished.")
import aspose.slides as slides
import aspose.pydrawing as drawing
imgs = []
with slides.Presentation(BytesIO(fnm)) as presentation:
for i, slide in enumerate(presentation.slides[from_page: to_page]):
try:
with BytesIO() as buffered:
slide.get_thumbnail(
0.1, 0.1).save(
buffered, drawing.imaging.ImageFormat.jpeg)
buffered.seek(0)
imgs.append(Image.open(buffered).copy())
except RuntimeError as e:
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
assert len(imgs) == len(
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
callback(0.9, "Image extraction finished")
self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))]
class Pdf(PdfParser):
def __init__(self):
super().__init__()
def __garbage(self, txt):
txt = txt.lower().strip()
if re.match(r"[0-9\.,%/-]+$", txt):
return True
if len(txt) < 3:
return True
return False
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(filename if not binary else binary,
zoomin, from_page, to_page, callback)
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
len(self.boxes), len(self.page_images))
res = []
for i in range(len(self.boxes)):
lines = "\n".join([b["text"] for b in self.boxes[i]
if not self.__garbage(b["text"])])
res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res, []
class PlainPdf(PlainParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, callback=None, **kwargs):
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
page_txt = []
for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text())
callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt], []
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
"""
The supported file formats are pdf, pptx.
Every page is treated as a chunk, and a thumbnail of every page is stored.
PPT files are parsed by this method automatically; no per-file setup is necessary.
"""
if parser_config is None:
parser_config = {}
eng = lang.lower() == "english"
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
res = []
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
ppt_parser = Ppt()
for pn, (txt, img) in enumerate(ppt_parser(
filename if not binary else binary, from_page, 1000000, callback)):
d = copy.deepcopy(doc)
pn += from_page
d["image"] = img
d["doc_type_kwd"] = "image"
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
tokenize(d, txt, eng)
res.append(d)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, _, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
vision_model=vision_model,
pdf_cls=Pdf,
**kwargs
)
if not sections:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for pn, (txt, img) in enumerate(sections):
d = copy.deepcopy(doc)
pn += from_page
if img:
d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
tokenize(d, txt, eng)
res.append(d)
return res
raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -36,7 +36,7 @@ def generate_signed_url(
"""
if base_url is None:
# Use SERVER_IP or default to localhost
server_url = f"http://{settings.SERVER_IP}:8000/api"
server_url = settings.FILE_LOCAL_SERVER_URL
base_url = server_url
# Calculate expiration timestamp

View File

@@ -11,16 +11,12 @@ from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
# TOOL_MANAGEMENT_AVAILABLE = True
# from app.db import get_db
from app.core.workflow.template_renderer import render_template
logger = logging.getLogger(__name__)
@@ -55,6 +51,8 @@ class WorkflowExecutor:
self.execution_config = workflow_config.get("execution_config", {})
self.start_node_id = None
self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None
self.checkpoint_config = RunnableConfig(
configurable={
@@ -127,7 +125,6 @@ class WorkflowExecutor:
"user_id": self.user_id,
"error": None,
"error_node": None,
"streaming_buffer": {}, # 流式缓冲区
"cycle_nodes": [
node.get("id")
for node in self.workflow_config.get("nodes")
@@ -139,9 +136,8 @@ class WorkflowExecutor:
}
}
def _build_final_output(self, result, elapsed_time):
def _build_final_output(self, result, elapsed_time, final_output):
node_outputs = result.get("node_outputs", {})
final_output = self._extract_final_output(node_outputs)
token_usage = self._aggregate_token_usage(node_outputs)
conversation_id = None
for node_id, node_output in node_outputs.items():
@@ -161,6 +157,21 @@ class WorkflowExecutor:
"error": result.get("error"),
}
def _update_end_activate(self, node_id):
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(node_id)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
@staticmethod
def _trans_output_string(content):
if isinstance(content, str):
return content
elif isinstance(content, list):
return "\n".join(content)
else:
return str(content)
def build_graph(self, stream=False) -> CompiledStateGraph:
"""构建 LangGraph
@@ -173,6 +184,7 @@ class WorkflowExecutor:
stream=stream,
)
self.start_node_id = builder.start_node_id
self.end_outputs = builder.end_node_map
graph = builder.build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
@@ -205,14 +217,34 @@ class WorkflowExecutor:
try:
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
full_content = ''
for end_info in self.end_outputs.values():
output_template = "".join([output.literal for output in end_info.outputs])
full_content += render_template(
output_template,
result.get("variables", {}),
result.get("runtime_vars", {}),
strict=False
)
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# Compute elapsed time
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return self._build_final_output(result, elapsed_time)
return self._build_final_output(result, elapsed_time, full_content)
except Exception as e:
# Compute elapsed time (recorded even on failure)
@@ -273,6 +305,7 @@ class WorkflowExecutor:
# 3. Execute workflow
try:
chunk_count = 0
full_content = ''
async for event in graph.astream(
initial_state,
@@ -293,21 +326,27 @@ class WorkflowExecutor:
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type in ("message", "node_chunk"):
if event_type == "node_chunk":
node_id = data.get("node_id")
if self.activate_end:
end_info = self.end_outputs.get(self.activate_end)
if not end_info or end_info.cursor >= len(end_info.outputs):
continue
current_output = end_info.outputs[end_info.cursor]
if current_output.is_variable and current_output.depends_on_node(node_id):
if data.get("done"):
end_info.cursor += 1
else:
full_content += data.get("chunk")
yield {
"event": "message",
"data": {
"chunk": data.get("chunk")
}
}
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
f"- execution_id: {self.execution_id}")
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix"),
"conversation_id": input_data.get("conversation_id"),
}
}
elif event_type == "node_error":
yield {
"event": event_type, # "message" or "node_chunk"
@@ -376,14 +415,109 @@ class WorkflowExecutor:
elif mode == "updates":
# Handle state updates - store final state
# TODO: streaming output point
for node_id in data.keys():
self._update_end_activate(node_id)
wait = False
state = graph.get_state(config=self.checkpoint_config)
node_outputs = state.values.get("runtime_vars", {})
for _ in data.keys():
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
while self.activate_end and not wait:
message = ''
logger.info(self.activate_end)
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
while content.activate:
if not content.is_variable:
full_content += content.literal
message += content.literal
else:
try:
chunk = evaluate_expression(
content.literal,
variables={},
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
break
content = end_info.outputs[end_info.cursor]
if end_info.cursor != len(end_info.outputs):
wait = True
else:
self.end_outputs.pop(self.activate_end)
self.activate_end = None
for node_id in data.keys():
self._update_end_activate(node_id)
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
f"- execution_id: {self.execution_id}")
result = graph.get_state(self.checkpoint_config).values
while self.activate_end:
message = ''
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
if not content.is_variable:
message += content.literal
else:
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
try:
chunk = evaluate_expression(
content.literal,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
if self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
# Compute elapsed time
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
result = graph.get_state(self.checkpoint_config).values
logger.info(result)
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
logger.info(
f"Workflow execution completed (streaming), "
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
@@ -392,7 +526,7 @@ class WorkflowExecutor:
# Emit the workflow_end event
yield {
"event": "workflow_end",
"data": self._build_final_output(result, elapsed_time)
"data": self._build_final_output(result, elapsed_time, full_content)
}
except Exception as e:
@@ -414,31 +548,6 @@ class WorkflowExecutor:
}
}
@staticmethod
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
@staticmethod
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
@@ -529,178 +638,3 @@ async def execute_workflow_stream(
)
async for event in executor.execute_stream(input_data):
yield event
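For reference, a caller of execute_workflow_stream receives a stream of "message" chunks followed by a single "workflow_end" event whose payload comes from _build_final_output. A minimal consumer sketch (caller-side code, assumed, not part of this commit):

async def collect_answer(executor: WorkflowExecutor, input_data: dict) -> str:
    """Concatenate streamed chunks and stop once the workflow_end event arrives (sketch)."""
    answer = ""
    async for event in executor.execute_stream(input_data):
        if event["event"] == "message":
            answer += event["data"]["chunk"]
        elif event["event"] == "workflow_end":
            # event["data"] carries outputs, token usage and error info from _build_final_output
            break
    return answer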
# ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
# """获取工作流可用的工具列表
#
# Args:
# workspace_id: 工作空间ID
# user_id: 用户ID
#
# Returns:
# 可用工具列表
# """
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.warning("工具管理系统不可用")
# return []
#
# try:
# db = next(get_db())
#
# # 创建工具注册表
# registry = ToolRegistry(db)
#
# # 注册内置工具类
# from app.core.tools.builtin import (
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
# )
# registry.register_tool_class(DateTimeTool)
# registry.register_tool_class(JsonTool)
# registry.register_tool_class(BaiduSearchTool)
# registry.register_tool_class(MinerUTool)
# registry.register_tool_class(TextInTool)
#
# # 获取活跃的工具
# import uuid
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
# active_tools = [tool for tool in tools if tool.status.value == "active"]
#
# # 转换为Langchain工具
# langchain_tools = []
# for tool_info in active_tools:
# try:
# tool_instance = registry.get_tool(tool_info.id)
# if tool_instance:
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
# langchain_tools.append(langchain_tool)
# except Exception as e:
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
#
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
# return langchain_tools
#
# except Exception as e:
# logger.error(f"获取工作流工具失败: {e}")
# return []
#
#
# class ToolWorkflowNode:
# """工具工作流节点 - 在工作流中执行工具"""
#
# def __init__(self, node_config: dict, workflow_config: dict):
# """初始化工具节点
#
# Args:
# node_config: 节点配置
# workflow_config: 工作流配置
# """
# self.node_config = node_config
# self.workflow_config = workflow_config
# self.tool_id = node_config.get("tool_id")
# self.tool_parameters = node_config.get("parameters", {})
#
# async def run(self, state: WorkflowState) -> WorkflowState:
# """执行工具节点"""
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.error("工具管理系统不可用")
# state["error"] = "工具管理系统不可用"
# return state
#
# try:
# from sqlalchemy.orm import Session
# db = next(get_db())
#
# # 创建工具执行器
# registry = ToolRegistry(db)
# executor = ToolExecutor(db, registry)
#
# # 准备参数(支持变量替换)
# parameters = self._prepare_parameters(state)
#
# # 执行工具
# result = await executor.execute_tool(
# tool_id=self.tool_id,
# parameters=parameters,
# user_id=uuid.UUID(state["user_id"]),
# workspace_id=uuid.UUID(state["workspace_id"])
# )
#
# # 更新状态
# node_id = self.node_config.get("id")
# if result.success:
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "output": result.data,
# "execution_time": result.execution_time,
# "token_usage": result.token_usage
# }
#
# # 更新运行时变量
# if isinstance(result.data, dict):
# for key, value in result.data.items():
# state["runtime_vars"][f"{node_id}.{key}"] = value
# else:
# state["runtime_vars"][f"{node_id}.result"] = result.data
# else:
# state["error"] = result.error
# state["error_node"] = node_id
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "error": result.error,
# "execution_time": result.execution_time
# }
#
# return state
#
# except Exception as e:
# logger.error(f"工具节点执行失败: {e}")
# state["error"] = str(e)
# state["error_node"] = self.node_config.get("id")
# return state
#
# def _prepare_parameters(self, state: WorkflowState) -> dict:
# """准备工具参数(支持变量替换)"""
# parameters = {}
#
# for key, value in self.tool_parameters.items():
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# # 变量替换
# var_path = value[2:-1]
#
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
# if "." in var_path:
# parts = var_path.split(".")
# current = state.get("variables", {})
#
# for part in parts:
# if isinstance(current, dict) and part in current:
# current = current[part]
# else:
# # 尝试从运行时变量获取
# runtime_key = ".".join(parts)
# current = state.get("runtime_vars", {}).get(runtime_key, value)
# break
#
# parameters[key] = current
# else:
# # 简单变量
# variables = state.get("variables", {})
# parameters[key] = variables.get(var_path, value)
# else:
# parameters[key] = value
#
# return parameters
#
#
# # 注册工具节点到NodeFactory如果存在
# try:
# from app.core.workflow.nodes import NodeFactory
# if hasattr(NodeFactory, 'register_node_type'):
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
# logger.info("工具节点已注册到工作流系统")
# except Exception as e:
# logger.warning(f"注册工具节点失败: {e}")

View File

@@ -1,12 +1,15 @@
import logging
import re
import uuid
from collections import defaultdict
from functools import lru_cache
from typing import Any
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
@@ -15,6 +18,153 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
logger = logging.getLogger(__name__)
class OutputContent(BaseModel):
"""
Represents a single output segment of an End node.
An output segment can be either:
- literal text (static string)
- a variable placeholder (e.g. {{ node.field }})
Each segment has its own activation state, which is especially
important in stream mode.
"""
literal: str = Field(
...,
description="Raw output content. Can be literal text or a variable placeholder."
)
activate: bool = Field(
...,
description=(
"Whether this output segment is currently active.\n"
"- True: allowed to be emitted/output\n"
"- False: blocked until activated by branch control"
)
)
is_variable: bool = Field(
...,
description=(
"Whether this segment represents a variable placeholder.\n"
"True -> variable (e.g. {{ node.field }})\n"
"False -> literal text"
)
)
def depends_on_node(self, node_id: str) -> bool:
"""
Check if this output segment depends on a specific node's variable.
This method examines the `literal` of the output segment to see if it
contains a variable placeholder referencing the given node in the form:
{{ node_id.field_name }}
It uses a regular expression to match the exact node ID, avoiding
false positives from substring matches (e.g., 'node1' should not match 'node10').
Args:
node_id (str): The ID of the node to check for in this segment's variable placeholders.
Returns:
bool:
- True if the segment contains a variable referencing the given node.
- False otherwise.
Example:
literal = "{{node1.name}}"
depends_on_node("node1") -> True
depends_on_node("node2") -> False
Usage:
This method is primarily used in stream mode to determine whether
a particular variable output segment should be activated when a
specific upstream node completes execution.
"""
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
pattern = re.compile(variable_pattern)
match = pattern.search(self.literal)
if match:
return True
return False
class StreamOutputConfig(BaseModel):
"""
Streaming output configuration for an End node.
This structure controls:
- whether the End node output is globally active
- which upstream branch nodes are responsible for activation
- how each output segment behaves in streaming mode
"""
activate: bool = Field(
...,
description=(
"Global activation state of the End node output.\n"
"If False, no output should be emitted until all control nodes are resolved."
)
)
control_nodes: list[str] = Field(
...,
description=(
"List of upstream branch node IDs that control this End node.\n"
"Each node must signal completion before output becomes active."
)
)
outputs: list[OutputContent] = Field(
...,
description="Ordered list of output segments parsed from the output template."
)
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates how many output segments have already been emitted."
)
)
def update_activate(self, node_id):
"""
Update activation state based on an upstream node completion.
This method is typically called when a branch/control node finishes execution.
Behavior:
1. If the node is a control node:
- Remove it from `control_nodes`
- If all control nodes are resolved, activate the entire output
2. Activate variable output segments that depend on this node:
- If an output segment is a variable
- And its literal references the completed node_id
- Mark that segment as active
"""
# Case 1: resolve control branch dependency
if node_id in self.control_nodes:
self.control_nodes.remove(node_id)
# All branch constraints resolved → enable output
if not self.control_nodes:
self.activate = True
# Case 2: activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
and self.outputs[i].depends_on_node(node_id)
):
self.outputs[i].activate = True
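To make the two models above concrete, here is a small illustration of how an End-node output template maps onto OutputContent segments and how update_activate flips their state. The node IDs are hypothetical; the splitting mirrors what _analyze_end_node_output below does.

import re

# OutputContent / StreamOutputConfig are the models defined above in graph_builder.py
# "Answer: {{llm_1.output}} (done)" splits into three ordered segments
template = "Answer: {{llm_1.output}} (done)"
segments = re.findall(r'\{\{.*?\}\}|[^{}]+', template)
# -> ['Answer: ', '{{llm_1.output}}', ' (done)']

variable_pattern = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}')
is_literal = [not variable_pattern.match(seg) for seg in segments]

config = StreamOutputConfig(
    activate=False,                   # blocked: an upstream branch controls this End node
    control_nodes=["if_else_1"],      # hypothetical branch node ID
    outputs=[
        OutputContent(literal=seg, activate=lit, is_variable=not lit)
        for seg, lit in zip(segments, is_literal)
    ],
    cursor=0,
)

config.update_activate("if_else_1")   # last control node resolved -> config.activate becomes True
config.update_activate("llm_1")       # the '{{llm_1.output}}' segment becomes active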
class GraphBuilder:
def __init__(
self,
@@ -29,6 +179,12 @@ class GraphBuilder:
self.start_node_id = None
self.end_node_ids = []
self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
self._analyze_end_node_output()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
@@ -43,79 +199,182 @@ class GraphBuilder:
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
"""
Analyze the prefix configuration for End nodes.
def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID.
This function scans each End node's output template, identifies
references to its direct upstream nodes, and extracts the prefix
string appearing before the first reference.
Args:
node_id (str): The unique identifier of the node.
Returns:
tuple:
- dict[str, str]: Mapping from upstream node ID to its End node prefix
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
str: The type of the node.
Raises:
RuntimeError: If no node with the given `node_id` exists.
"""
import re
try:
return self.node_map[node_id]["type"]
except KeyError:
raise RuntimeError(f"Node not found: Id={node_id}")
prefixes = {}
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
"""Find upstream branch nodes for a given target node in the workflow graph.
# 找到所有 End 节点
This method identifies all upstream control (branch) nodes that can affect
the execution of `target_node`. If `target_node` is reachable from a start
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
and non-branch nodes, recursively traversing upstream through non-branch
nodes. If any non-branch upstream path does not lead to a branch node,
the result will indicate that no valid upstream branch node exists.
Args:
target_node (str): The identifier of the target node.
Returns:
tuple[bool, tuple[str]]:
- has_branch (bool): True if all upstream non-branch paths lead to at least
one branch node; False if any path reaches a start node without a branch.
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
"""
source_nodes = [
edge.get("source")
for edge in self.edges
if edge.get("target") == target_node
]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return False, tuple()
branch_nodes = []
non_branch_nodes = []
for node_id in source_nodes:
if self.get_node_type(node_id) in BRANCH_NODES:
branch_nodes.append(node_id)
else:
non_branch_nodes.append(node_id)
has_branch = True
for node_id in non_branch_nodes:
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
has_branch = has_branch and node_has_branch
if not has_branch:
break
branch_nodes.extend(nodes)
if not has_branch:
branch_nodes = []
return has_branch, tuple(set(branch_nodes))
def _analyze_end_node_output(self):
"""
Analyze output templates of all End nodes and generate StreamOutputConfig.
This method is responsible for parsing the `output` field of End nodes,
splitting literal text and variable placeholders (e.g. {{ node.field }}),
and determining whether each output segment should be activated immediately
or controlled by upstream branch nodes.
In stream mode:
- If the End node is controlled by any upstream branch node, the output
will be initially inactive and controlled by those branch nodes.
- Otherwise, the output is activated immediately.
In non-stream mode:
- All outputs are activated by default.
"""
# Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
# Iterate through each End node to analyze its output
for end_node in end_nodes:
end_node_id = end_node.get("id")
output_template = end_node.get("config", {}).get("output")
config = end_node.get("config", {})
output = config.get("output")
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
if not output_template:
# Skip End nodes without output configuration
if not output:
continue
# Find all node references in the template
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
matches = list(re.finditer(pattern, output_template))
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Identify all direct upstream nodes connected to the End node
direct_upstream_nodes = []
for edge in self.edges:
if edge.get("target") == end_node_id:
source_node_id = edge.get("source")
direct_upstream_nodes.append(source_node_id)
# Split output into ordered segments
output_template = list(re.findall(pattern, output))
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
# Determine whether each segment is literal text
# True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value)
output_flag = [
not bool(variable_pattern.match(item))
for item in output_template
]
# 找到第一个直接上游节点的引用
for match in matches:
referenced_node_id = match.group(1)
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
# Stream mode: output activation depends on upstream branch nodes
if self.stream:
# Find upstream branch nodes that can control this End node
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
if referenced_node_id in direct_upstream_nodes:
# 这是直接上游节点的引用,提取前缀
prefix = output_template[:match.start()]
# Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig(
# If there is no upstream branch, output is active immediately
activate=not has_branch,
logger.info(f"[Prefix Analysis] "
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
# Branch nodes that control activation of this End node
control_nodes=list(control_nodes),
# 标记这个节点为"相邻且被引用"
adjacent_and_referenced.add(referenced_node_id)
# Convert output segments into OutputContent objects
outputs=list(
[
OutputContent(
literal=output_string,
# Literal text can be activated immediately unless blocked by branch
activate=activate,
# Variable segments are marked explicitly
is_variable=not activate
)
for output_string, activate in zip(output_template, output_flag)
]
),
# Cursor for streaming output (initially 0)
cursor=0
)
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {not has_branch}, "
f"control_nodes: {control_nodes},"
f"output: {output_template},"
f"output_activate: {output_flag}")
if prefix:
prefixes[referenced_node_id] = prefix
logger.info(f"[Prefix Analysis] "
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
# 只处理第一个直接上游节点的引用
break
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
return prefixes, adjacent_and_referenced
# Non-stream mode: all outputs are activated by default
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
activate=True,
control_nodes=[],
outputs=list(
[
OutputContent(
literal=output_string,
activate=True,
is_variable=not activate
)
for output_string, activate in zip(output_template, output_flag)
]
),
cursor=0
)
def add_nodes(self):
"""Add all nodes from the workflow configuration to the state graph.
@@ -135,9 +394,6 @@ class GraphBuilder:
Returns:
None
"""
# Analyze End node prefixes if in stream mode
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
@@ -171,17 +427,6 @@ class GraphBuilder:
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# Inject End node prefix configuration if in stream mode
if self.stream and node_id in end_prefixes:
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"Injected End prefix for node {node_id}")
# Mark nodes as adjacent and referenced to End node in stream mode
if self.stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
# Wrap node's run method to avoid closure issues
if self.stream:
# Stream mode: create an async generator function
@@ -261,6 +506,7 @@ class GraphBuilder:
for source_node, branches in conditional_edges.items():
def make_router(src, branch_list):
"""reate a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets):
def node(s):
# NOTE: NOP NODE MUST NOT MODIFY STATE

View File

@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
error: str | None
error_node: str | None
# Streaming buffer (stores real-time streaming output of nodes)
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# node activate status
activate: Annotated[dict[str, bool], merge_activate_state]
@@ -300,7 +296,7 @@ class BaseNode(ABC):
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"跳过节点{self.node_id}")
logger.info(f"jump node: {self.node_id}")
return
import time
@@ -313,19 +309,6 @@ class BaseNode(ABC):
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
@@ -340,66 +323,25 @@ class BaseNode(ABC):
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
if item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
else:
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
content = str(item.get("chunk"))
done = item.get("done", False)
chunks.append(content)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"type": "node_chunk",
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
"chunk": content,
"done": done
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
elapsed_time = time.time() - start_time
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
@@ -426,16 +368,6 @@ class BaseNode(ABC):
"looping": state["looping"]
}
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update | self.trans_activate(state)
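Under the new protocol, a node's execute_stream yields dicts instead of raw strings: intermediate items carry {"__final__": False, "chunk": ...} (optionally with "done": True as an end-of-stream marker), and the last item carries {"__final__": True, "result": ...}. A hypothetical minimal node, for illustration only:

class EchoNode(BaseNode):
    """Hypothetical node: streams the user message back in two chunks (sketch)."""

    async def execute_stream(self, state: WorkflowState):
        text = str(self.get_variable("sys.message", state) or "")
        half = len(text) // 2
        for piece in (text[:half], text[half:]):
            # forwarded by the base-node wrapper as a "node_chunk" writer event
            yield {"__final__": False, "chunk": piece}
        # end-of-stream marker, mirroring what the LLM node emits
        yield {"__final__": False, "chunk": "", "done": True}
        # completion marker: stored as the node's final result
        yield {"__final__": True, "result": text}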

View File

@@ -1,3 +1,3 @@
from app.core.workflow.nodes.code.node import CodeNode
__all__ = ["CodeNode"]
__all__ = ["CodeNode"]

View File

@@ -7,7 +7,6 @@ from textwrap import dedent
from typing import Any
import httpx
from sympy.physics.vector import vlatex
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType

View File

@@ -1,5 +1,4 @@
import asyncio
import copy
import logging
import re
from typing import Any

View File

@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType

View File

@@ -5,10 +5,8 @@ End 节点实现
"""
import logging
import re
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
@@ -37,24 +35,8 @@ class EndNode(BaseNode):
# If an output template is configured, render it; otherwise use the default output
if output_template:
output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
else:
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
])
output = "工作流已完成"
output = ""
# Statistics (for logging)
node_outputs = state.get("node_outputs", {})
@@ -63,274 +45,3 @@ class EndNode(BaseNode):
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
"""Execute End node business logic (streaming)
Smart output strategy:
1. Check if template references a direct upstream LLM node
2. If yes, only output the part AFTER that reference (suffix)
3. Prefix and LLM content have already been sent during LLM node streaming
Note: Only LLM nodes get this special treatment. Other node types output normally.
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
- Direct upstream LLM node is llm_qa
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
- LLM content was streamed during LLM node execution
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
Args:
state: Workflow state
Yields:
Completion marker
"""
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
# 获取配置的输出模板
output_template = self.config.get("output")
if not output_template:
output = "工作流已完成"
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": "",
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
}
])
yield {"__final__": True, "result": output}
return
# Find direct upstream LLM nodes
direct_upstream_llm_nodes = []
for edge in self.workflow_config.get("edges", []):
if edge.get("target") == self.node_id:
source_node_id = edge.get("source")
# Check if the source node is an LLM node
for node in self.workflow_config.get("nodes", []):
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
direct_upstream_llm_nodes.append(source_node_id)
break
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
# Parse template parts
parts = self._parse_template_parts(output_template, state)
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
for i, part in enumerate(parts):
logger.info(f"[模板解析] part[{i}]: {part}")
# Find the first reference to a direct upstream LLM node
upstream_llm_ref_index = None
for i, part in enumerate(parts):
if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
upstream_llm_ref_index = i
logger.info(f"节点 {self.node_id} 找到直接上游 LLM 节点 {part['node_id']} 的引用,索引: {i}")
break
if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state, strict=False)
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# Send complete content via writer (as a single message chunk)
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": output,
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": output
}
])
# yield completion marker
yield {"__final__": True, "result": output}
return
# Has reference to direct upstream LLM node, only output the part after that reference (suffix)
logger.info(
f"节点 {self.node_id} 检测到直接上游 LLM 节点引用,只输出后缀部分(从索引 {upstream_llm_ref_index + 1} 开始)")
# Collect suffix parts
suffix_parts = []
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_llm_ref_index + 1}{len(parts) - 1}")
for i in range(upstream_llm_ref_index + 1, len(parts)):
part = parts[i]
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
if part["type"] == "static":
# 静态文本
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
suffix_parts.append(part["content"])
elif part["type"] == "dynamic":
# Other dynamic references (if there are multiple references)
node_id = part["node_id"]
field = part["field"]
# Use VariablePool to get variable value
pool = self.get_variable_pool(state)
try:
# Try to get variable value with default empty string
content = pool.get([node_id, field], default="")
logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
except Exception as e:
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
content = ""
# Convert to string if not None
suffix_parts.append(str(content) if content is not None else "")
# 拼接后缀
suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state, strict=False)
state['messages'].extend([
{
"role": "user",
"content": self.get_variable("sys.message", state)
},
{
"role": "assistant",
"content": full_output
}
])
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
if suffix:
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
# 一次性输出后缀(作为单个 chunk
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
# 而是通过 writer 直接发送
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End 节点的输出使用 message 类型
"node_id": self.node_id,
"chunk": suffix,
"full_content": full_output, # full_content 是完整的渲染结果(前缀+LLM+后缀)
"chunk_index": 1,
"is_suffix": True
})
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
else:
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!"
f"upstream_llm_ref_index={upstream_llm_ref_index}, parts数量={len(parts)}")
# 统计信息
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
# yield 完成标记(包含完整输出)
yield {"__final__": True, "result": full_output}

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None= None
self.typed_config: IfElseNodeConfig | None = None
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:

View File

@@ -7,18 +7,18 @@ LLM 节点实现
import logging
import re
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from langchain_core.messages import AIMessage
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
@@ -231,42 +231,14 @@ class LLMNode(BaseNode):
Text fragments (chunks) or a completion marker
"""
self.typed_config = LLMNodeConfig(**self.config)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# Accumulate the full response
full_response = ""
last_chunk = None
chunk_count = 0
# Call the LLM (streaming); supports a string or a list of messages
@@ -284,12 +256,19 @@ class LLMNode(BaseNode):
# Only process non-empty content
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# Stream each text fragment back
yield content
yield {
"__final__": False,
"chunk": content
}
yield {
"__final__": False,
"chunk": "",
"done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# Build the complete AIMessage (with metadata)

View File

@@ -1,8 +1,6 @@
import uuid
from uuid import UUID
from pydantic import Field
from typing import Literal
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -12,7 +10,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
...
)
config_id: UUID = Field(
config_id: UUID | int = Field(
...
)

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory(
end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
db=db,

View File

@@ -5,6 +5,7 @@ from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ClassifierConfig(BaseModel):
"""分类器节点配置"""
@@ -13,7 +14,7 @@ class ClassifierConfig(BaseModel):
class QuestionClassifierNodeConfig(BaseNodeConfig):
"""问题分类器节点配置"""
model_id: uuid.UUID = Field(..., description="LLM模型ID")
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")

View File

@@ -18,30 +18,30 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
base_url = api_config.api_base
model_type = config.type
return RedBearLLM(
RedBearModelConfig(
model_name=model_name,
@@ -64,7 +64,7 @@ class QuestionClassifierNode(BaseNode):
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config)
@@ -74,11 +74,12 @@ class QuestionClassifierNode(BaseNode):
categories = self.typed_config.categories or []
category_names = [class_item.class_name.strip() for class_item in categories]
category_count = len(category_names)
if not question:
logger.warning(
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
f"默认分支{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE}"
f"分类总数: {category_count})"
)
# If the category list is empty, return the default unknown branch; otherwise return CASE1
if category_count > 0:

View File

@@ -1,4 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"]
__all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -16,11 +16,11 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
self.typed_config = ToolNodeConfig(**self.config)
@@ -28,21 +28,21 @@ class ToolNode(BaseNode):
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable("sys.workspace_id", state)
# If there is no tenant ID, try to resolve it from the workspace ID
if not tenant_id:
if workspace_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
if not tenant_id:
logger.error(f"节点 {self.node_id} 缺少租户ID")
return {
"success": False,
"data": "缺少租户ID"
}
# Render tool parameters
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
@@ -55,9 +55,9 @@ class ToolNode(BaseNode):
# Non-template parameters (numbers/booleans/plain strings) keep their original value
rendered_value = param_template
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
# Execute the tool
with get_db_read() as db:
tool_service = ToolService(db)
@@ -79,7 +79,7 @@ class ToolNode(BaseNode):
else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
"error_code": result.error_code,
"execution_time": result.execution_time
}
}
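As the comments above note, only parameter values containing a {{ ... }} template are rendered against the workflow state; everything else passes through untouched. A small hypothetical configuration to illustrate:

# tool_parameters as they might appear in a node config (hypothetical values)
tool_parameters = {
    "query": "{{ start.question }}",  # matches TEMPLATE_PATTERN -> rendered from state
    "top_k": 5,                       # number -> kept as-is
    "lang": "zh",                     # plain string without {{ }} -> kept as-is
}
# after rendering, rendered_parameters might look like:
# {"query": "What is RAG?", "top_k": 5, "lang": "zh"}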

View File

@@ -10,7 +10,7 @@ class MemoryConfig(Base):
# Primary key
config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
config_id_old = Column(Integer, nullable=True, comment="备份的配置ID")
# Basic information
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")

View File

@@ -16,6 +16,10 @@ class Tenants(Base):
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
is_active = Column(Boolean, default=True)
# SSO external linkage fields
external_id = Column(String(100), nullable=True, index=True)  # External enterprise ID
external_source = Column(String(50), nullable=True)  # Source system
# Relationship to users - one tenant has many users
users = relationship("User", back_populates="tenant")

View File

@@ -18,6 +18,10 @@ class User(Base):
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
last_login_at = Column(DateTime, nullable=True)  # Last login time, nullable
# SSO external linkage fields
external_id = Column(String(100), nullable=True)  # External user ID
external_source = Column(String(50), nullable=True)  # Source system
current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True)  # Current workspace ID (nullable)
# Foreign key to tenant - each user belongs to exactly one tenant

View File

@@ -0,0 +1,74 @@
# app/plugins/__init__.py
"""
Plugin system - supports an open-source core plus closed-source premium modules.
Usage:
1. Open-source edition (community): base features
2. Commercial edition (enterprise): loads the advanced implementations from the premium package
"""
import os
from typing import Dict, Any, Optional
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Edition identifier
EDITION = os.environ.get("EDITION", "community")
IS_ENTERPRISE = EDITION == "enterprise"
# Plugin registry
_plugins: Dict[str, Any] = {}
# Router registry (used to dynamically register routes from closed-source modules)
_routers: list = []
def is_enterprise() -> bool:
"""是否为商业版"""
return IS_ENTERPRISE
def list_plugins() -> list:
"""列出所有已注册插件"""
return list(_plugins.keys())
def register_plugin(name: str, instance: Any):
"""注册插件"""
_plugins[name] = instance
logger.info(f"插件已注册: {name}")
def get_plugin(name: str) -> Optional[Any]:
"""获取插件实例"""
return _plugins.get(name)
def register_router(router, prefix: str = "", tags: list = None):
"""注册路由(供闭源模块使用)"""
_routers.append({
"router": router,
"prefix": prefix,
"tags": tags or []
})
logger.info(f"路由已注册: {prefix}")
def get_registered_routers() -> list:
"""获取所有注册的路由"""
return _routers
def register_premium_routers(app):
"""
Mount the premium modules' routers onto the FastAPI app.
Called from the commercial edition's main.py.
"""
for router_info in _routers:
app.include_router(
router_info["router"],
prefix=f"/api{router_info['prefix']}",
tags=router_info["tags"]
)
logger.info(f"Premium 路由已挂载: /api{router_info['prefix']}")

View File

@@ -528,7 +528,8 @@ class WorkflowService:
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"]
content=message["content"],
meta_data=None if message["role"] == "user" else {"usage": token_usage}
)
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
@@ -678,7 +679,8 @@ class WorkflowService:
self.conversation_service.add_message(
conversation_id=conversation_id_uuid,
role=message["role"],
content=message["content"]
content=message["content"],
meta_data=None if message["role"] == "user" else {"usage": token_usage}
)
logger.info(f"Workflow Run Success, "
f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")

View File

@@ -15,6 +15,7 @@ services:
networks:
- default
- celery
- sandbox
depends_on:
- worker-memory
- worker-document
@@ -63,5 +64,16 @@ services:
depends_on:
- worker-memory
sandbox:
image: redbear_sandbox:latest
container_name: sandbox
ports:
- "8194"
command: /code/.venv/bin/python main.py
restart: unless-stopped
networks:
- sandbox
networks:
celery:
sandbox:

View File

@@ -75,6 +75,7 @@ ENABLE_SINGLE_SESSION=
MAX_FILE_SIZE=52428800 # 50MB: 50 * 1024 * 1024
FILE_PATH=/files
FILE_LOCAL_SERVER_URL="http://localhost:8000/api"
# Storage Backend Configuration
# Supported values: local, oss, s3
# Default: local

View File

@@ -0,0 +1,57 @@
"""202601271517
Revision ID: 75f0ec80e50b
Revises: 325b759cd66b
Create Date: 2026-01-27 15:26:48.696600
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '75f0ec80e50b'
down_revision: Union[str, None] = '325b759cd66b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('memory_config', 'config_id',
existing_type=sa.UUID(),
comment='配置ID',
existing_nullable=False)
op.alter_column('memory_config', 'config_id_old',
existing_type=sa.INTEGER(),
comment='备份的配置ID',
existing_comment='配置ID',
existing_nullable=True)
op.add_column('tenants', sa.Column('external_id', sa.String(length=100), nullable=True))
op.add_column('tenants', sa.Column('external_source', sa.String(length=50), nullable=True))
op.create_index(op.f('ix_tenants_external_id'), 'tenants', ['external_id'], unique=False)
op.add_column('users', sa.Column('external_id', sa.String(length=100), nullable=True))
op.add_column('users', sa.Column('external_source', sa.String(length=50), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('users', 'external_source')
op.drop_column('users', 'external_id')
op.drop_index(op.f('ix_tenants_external_id'), table_name='tenants')
op.drop_column('tenants', 'external_source')
op.drop_column('tenants', 'external_id')
op.alter_column('memory_config', 'config_id_old',
existing_type=sa.INTEGER(),
comment='配置ID',
existing_comment='备份的配置ID',
existing_nullable=True)
op.alter_column('memory_config', 'config_id',
existing_type=sa.UUID(),
comment=None,
existing_comment='配置ID',
existing_nullable=False)
# ### end Alembic commands ###

View File

@@ -88,7 +88,6 @@ dependencies = [
"cachetools==6.2.1",
"ruamel.yaml==0.18.10",
"strenum==0.4.15",
"aspose-slides==24.12.0",
"opencv-python==4.10.0.84",
"numpy>=1.26.0,<2.0.0",
"huggingface-hub==0.25.2",

View File

@@ -83,7 +83,6 @@ olefile==0.47
cachetools==6.2.1
ruamel.yaml==0.18.10
strenum==0.4.15
aspose-slides==24.12.0
opencv-python==4.10.0.84
numpy>=1.26.0,<2.0.0
huggingface-hub==0.25.2