Merge branch 'refs/heads/develop' into fix/memory_bug_fix
@@ -310,7 +310,7 @@ async def get_file_url(
    try:
        if permanent:
            # Generate permanent URL (no expiration check)
            server_url = f"http://{settings.SERVER_IP}:8000/api"
            server_url = settings.FILE_LOCAL_SERVER_URL
            url = f"{server_url}/storage/permanent/{file_id}"
            return success(
                data={
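# Illustrative sketch (not part of this commit): the net effect of the swap above. The
# permanent-URL base now comes from configuration rather than a hard-coded SERVER_IP;
# the values below are hypothetical defaults.
server_url = "http://localhost:8000/api"  # settings.FILE_LOCAL_SERVER_URL default
file_id = "abc123"
assert f"{server_url}/storage/permanent/{file_id}" == "http://localhost:8000/api/storage/permanent/abc123"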
@@ -9,6 +9,25 @@ load_dotenv()

class Settings:
    # ========================================================================
    # Deployment Mode Configuration
    # ========================================================================
    # community: Community edition (open source, limited features)
    # cloud: SaaS cloud edition (full features, usage-based billing)
    # enterprise: Enterprise on-premises edition (license-controlled)
    DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")

    # License configuration (Enterprise edition)
    LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
    LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")

    # Billing service configuration (SaaS edition)
    BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")

    # Base URLs (used for SSO callbacks, etc.)
    BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
    FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")

    ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
    # API Keys Configuration
    OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -72,6 +91,10 @@ class Settings:

    # Single Sign-On configuration
    ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"

    # SSO trusted-login (passwordless SSO) configuration
    SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
    SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")

    # File Upload
    MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
@@ -107,6 +130,7 @@ class Settings:

    # Server Configuration
    SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
    FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")

    # ========================================================================
    # Internal Configuration (not in .env, used by application code)
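# Illustrative sketch (assumption, not in this commit): one way code might gate features
# on DEPLOYMENT_MODE; the helper name `feature_enabled` is hypothetical.
def feature_enabled(mode: str, feature: str) -> bool:
    community_features = {"chat", "workflow"}  # open-source core
    if mode == "community":
        return feature in community_features
    return mode in ("cloud", "enterprise")  # full feature set

assert feature_enabled("community", "chat")
assert not feature_enabled("community", "billing")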
@@ -1,165 +0,0 @@

import copy
import re
from io import BytesIO
from PIL import Image

from app.core.rag.nlp import tokenize, is_english
from app.core.rag.nlp import rag_tokenizer
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
from app.core.rag.deepdoc.parser.ppt_parser import RAGPptParser as PptParser
from PyPDF2 import PdfReader as pdf2_read
from app.core.rag.app.naive import by_plaintext, PARSERS


class Ppt(PptParser):
    def __call__(self, fnm, from_page, to_page, callback=None):
        txts = super().__call__(fnm, from_page, to_page)

        callback(0.5, "Text extraction finished.")
        import aspose.slides as slides
        import aspose.pydrawing as drawing
        imgs = []
        with slides.Presentation(BytesIO(fnm)) as presentation:
            for i, slide in enumerate(presentation.slides[from_page: to_page]):
                try:
                    with BytesIO() as buffered:
                        slide.get_thumbnail(
                            0.1, 0.1).save(
                            buffered, drawing.imaging.ImageFormat.jpeg)
                        buffered.seek(0)
                        imgs.append(Image.open(buffered).copy())
                except RuntimeError as e:
                    raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
        assert len(imgs) == len(
            txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
        callback(0.9, "Image extraction finished")
        self.is_english = is_english(txts)
        return [(txts[i], imgs[i]) for i in range(len(txts))]


class Pdf(PdfParser):
    def __init__(self):
        super().__init__()

    def __garbage(self, txt):
        txt = txt.lower().strip()
        if re.match(r"[0-9\.,%/-]+$", txt):
            return True
        if len(txt) < 3:
            return True
        return False

    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, zoomin=3, callback=None):
        from timeit import default_timer as timer
        start = timer()
        callback(msg="OCR started")
        self.__images__(filename if not binary else binary,
                        zoomin, from_page, to_page, callback)
        callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
        assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
            len(self.boxes), len(self.page_images))
        res = []
        for i in range(len(self.boxes)):
            lines = "\n".join([b["text"] for b in self.boxes[i]
                               if not self.__garbage(b["text"])])
            res.append((lines, self.page_images[i]))
        callback(0.9, "Page {}~{}: Parsing finished".format(
            from_page, min(to_page, self.total_page)))
        return res, []

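# Illustrative sketch restating the __garbage filter above: a text box is dropped when it
# is only digits/punctuation or shorter than three characters.
import re

def looks_like_garbage(txt: str) -> bool:
    txt = txt.lower().strip()
    return bool(re.match(r"[0-9\.,%/-]+$", txt)) or len(txt) < 3

assert looks_like_garbage("12.5%") and looks_like_garbage("a")
assert not looks_like_garbage("Revenue grew")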
class PlainPdf(PlainParser):
    def __call__(self, filename, binary=None, from_page=0,
                 to_page=100000, callback=None, **kwargs):
        self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
        page_txt = []
        for page in self.pdf.pages[from_page: to_page]:
            page_txt.append(page.extract_text())
        callback(0.9, "Parsing finished")
        return [(txt, None) for txt in page_txt], []


def chunk(filename, binary=None, from_page=0, to_page=100000,
          lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
    """
    The supported file formats are pdf and pptx.
    Every page is treated as a chunk, and the thumbnail of every page is stored.
    PPT files are parsed by this method automatically; no per-file setup is necessary.
    """
    if parser_config is None:
        parser_config = {}
    eng = lang.lower() == "english"
    doc = {
        "docnm_kwd": filename,
        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
    }
    doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
    res = []
    if re.search(r"\.pptx?$", filename, re.IGNORECASE):
        if not binary:
            with open(filename, "rb") as f:
                binary = f.read()
        ppt_parser = Ppt()
        for pn, (txt, img) in enumerate(ppt_parser(
                filename if not binary else binary, from_page, 1000000, callback)):
            d = copy.deepcopy(doc)
            pn += from_page
            d["image"] = img
            d["doc_type_kwd"] = "image"
            d["page_num_int"] = [pn + 1]
            d["top_int"] = [0]
            d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
            tokenize(d, txt, eng)
            res.append(d)
        return res
    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")

        if isinstance(layout_recognizer, bool):
            layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"

        name = layout_recognizer.strip().lower()
        parser = PARSERS.get(name, by_plaintext)
        callback(0.1, "Start to parse.")

        sections, _, _ = parser(
            filename=filename,
            binary=binary,
            from_page=from_page,
            to_page=to_page,
            lang=lang,
            callback=callback,
            vision_model=vision_model,
            pdf_cls=Pdf,
            **kwargs
        )

        if not sections:
            return []

        if name in ["tcadp", "docling", "mineru"]:
            parser_config["chunk_token_num"] = 0

        callback(0.8, "Finish parsing.")

        for pn, (txt, img) in enumerate(sections):
            d = copy.deepcopy(doc)
            pn += from_page
            if img:
                d["image"] = img
            d["page_num_int"] = [pn + 1]
            d["top_int"] = [0]
            d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
            tokenize(d, txt, eng)
            res.append(d)
        return res

    raise NotImplementedError(
        "file type not supported yet (pptx, pdf supported)")


if __name__ == "__main__":
    import sys

    def dummy(a, b):
        pass
    chunk(sys.argv[1], callback=dummy)
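# Illustrative sketch: the callback contract used throughout this file. chunk() forwards
# a (progress, message) pair, and some call sites pass msg= only (e.g. callback(msg="OCR started")),
# so a tolerant signature helps; the printer below is hypothetical.
def progress_printer(progress=None, msg=""):
    if progress is not None:
        print(f"[{progress:.0%}] {msg}")
    else:
        print(msg)

# chunk("slides.pptx", callback=progress_printer)  # one (text, thumbnail) chunk per page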
@@ -36,7 +36,7 @@ def generate_signed_url(
    """
    if base_url is None:
        # Use SERVER_IP or default to localhost
        server_url = f"http://{settings.SERVER_IP}:8000/api"
        server_url = settings.FILE_LOCAL_SERVER_URL
        base_url = server_url

    # Calculate expiration timestamp
@@ -11,16 +11,12 @@ from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph

from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType

# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
# TOOL_MANAGEMENT_AVAILABLE = True
# from app.db import get_db
from app.core.workflow.template_renderer import render_template

logger = logging.getLogger(__name__)
@@ -55,6 +51,8 @@ class WorkflowExecutor:
        self.execution_config = workflow_config.get("execution_config", {})

        self.start_node_id = None
        self.end_outputs: dict[str, StreamOutputConfig] = {}
        self.activate_end: str | None = None

        self.checkpoint_config = RunnableConfig(
            configurable={
@@ -127,7 +125,6 @@ class WorkflowExecutor:
            "user_id": self.user_id,
            "error": None,
            "error_node": None,
            "streaming_buffer": {},  # streaming buffer
            "cycle_nodes": [
                node.get("id")
                for node in self.workflow_config.get("nodes")
@@ -139,9 +136,8 @@ class WorkflowExecutor:
            }
        }

    def _build_final_output(self, result, elapsed_time):
    def _build_final_output(self, result, elapsed_time, final_output):
        node_outputs = result.get("node_outputs", {})
        final_output = self._extract_final_output(node_outputs)
        token_usage = self._aggregate_token_usage(node_outputs)
        conversation_id = None
        for node_id, node_output in node_outputs.items():
@@ -161,6 +157,21 @@ class WorkflowExecutor:
            "error": result.get("error"),
        }

    def _update_end_activate(self, node_id):
        for node in self.end_outputs.keys():
            self.end_outputs[node].update_activate(node_id)
            if self.end_outputs[node].activate and self.activate_end is None:
                self.activate_end = node
    @staticmethod
    def _trans_output_string(content):
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            return "\n".join(content)
        else:
            return str(content)
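    # Editor's sketch (not part of the commit), illustrating the coercion rules above:
    #   _trans_output_string("hi")        -> "hi"
    #   _trans_output_string(["a", "b"])  -> "a\nb"
    #   _trans_output_string(42)          -> "42"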
    def build_graph(self, stream=False) -> CompiledStateGraph:
        """Build the LangGraph
@@ -173,6 +184,7 @@ class WorkflowExecutor:
            stream=stream,
        )
        self.start_node_id = builder.start_node_id
        self.end_outputs = builder.end_node_map
        graph = builder.build()
        logger.info(f"Workflow graph built: execution_id={self.execution_id}")
@@ -205,14 +217,34 @@ class WorkflowExecutor:
        try:

            result = await graph.ainvoke(initial_state, config=self.checkpoint_config)

            full_content = ''
            for end_info in self.end_outputs.values():
                output_template = "".join([output.literal for output in end_info.outputs])
                full_content += render_template(
                    output_template,
                    result.get("variables", {}),
                    result.get("runtime_vars", {}),
                    strict=False
                )
            result["messages"].extend(
                [
                    {
                        "role": "user",
                        "content": input_data.get("message", '')
                    },
                    {
                        "role": "assistant",
                        "content": full_content
                    }
                ]
            )
            # Compute elapsed time
            end_time = datetime.datetime.now()
            elapsed_time = (end_time - start_time).total_seconds()

            logger.info(f"Workflow execution finished: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")

            return self._build_final_output(result, elapsed_time)
            return self._build_final_output(result, elapsed_time, full_content)

        except Exception as e:
            # Compute elapsed time (recorded even on failure)
@@ -273,6 +305,7 @@ class WorkflowExecutor:
            # 3. Execute workflow
            try:
                chunk_count = 0
                full_content = ''

                async for event in graph.astream(
                    initial_state,
@@ -293,21 +326,27 @@ class WorkflowExecutor:
                    # Handle custom streaming events (chunks from nodes via stream writer)
                    chunk_count += 1
                    event_type = data.get("type", "node_chunk")  # "message" or "node_chunk"
                    if event_type in ("message", "node_chunk"):
                        if event_type == "node_chunk":
                            node_id = data.get("node_id")
                            if self.activate_end:
                                end_info = self.end_outputs.get(self.activate_end)
                                if not end_info or end_info.cursor >= len(end_info.outputs):
                                    continue
                                current_output = end_info.outputs[end_info.cursor]
                                if current_output.is_variable and current_output.depends_on_node(node_id):
                                    if data.get("done"):
                                        end_info.cursor += 1
                                    else:
                                        full_content += data.get("chunk")
                                        yield {
                                            "event": "message",
                                            "data": {
                                                "chunk": data.get("chunk")
                                            }
                                        }
                        logger.info(f"[CUSTOM] ✅ received {event_type} #{chunk_count} from {data.get('node_id')}"
                                    f"- execution_id: {self.execution_id}")
                        yield {
                            "event": event_type,  # "message" or "node_chunk"
                            "data": {
                                "node_id": data.get("node_id"),
                                "chunk": data.get("chunk"),
                                "full_content": data.get("full_content"),
                                "chunk_index": data.get("chunk_index"),
                                "is_prefix": data.get("is_prefix"),
                                "is_suffix": data.get("is_suffix"),
                                "conversation_id": input_data.get("conversation_id"),
                            }
                        }

                    elif event_type == "node_error":
                        yield {
                            "event": event_type,  # "message" or "node_chunk"
@@ -376,14 +415,109 @@ class WorkflowExecutor:

                    elif mode == "updates":
                        # Handle state updates - store final state
                        # TODO: streaming output point
                        for node_id in data.keys():
                            self._update_end_activate(node_id)
                        wait = False
                        state = graph.get_state(config=self.checkpoint_config)
                        node_outputs = state.values.get("runtime_vars", {})
                        for _ in data.keys():
                            node_outputs = node_outputs | data.get(_).get("runtime_vars", {})

                        while self.activate_end and not wait:
                            message = ''
                            logger.info(self.activate_end)
                            end_info = self.end_outputs[self.activate_end]
                            content = end_info.outputs[end_info.cursor]
                            while content.activate:
                                if not content.is_variable:
                                    full_content += content.literal
                                    message += content.literal
                                else:
                                    try:
                                        chunk = evaluate_expression(
                                            content.literal,
                                            variables={},
                                            node_outputs=node_outputs
                                        )
                                        chunk = self._trans_output_string(chunk)
                                        message += chunk
                                        full_content += chunk
                                    except ValueError:
                                        pass
                                end_info.cursor += 1
                                if end_info.cursor == len(end_info.outputs):
                                    break
                                content = end_info.outputs[end_info.cursor]
                            if end_info.cursor != len(end_info.outputs):
                                wait = True
                            else:
                                self.end_outputs.pop(self.activate_end)
                                self.activate_end = None
                                for node_id in data.keys():
                                    self._update_end_activate(node_id)
                            if message:
                                yield {
                                    "event": "message",
                                    "data": {
                                        "chunk": message
                                    }
                                }

                        logger.debug(f"[UPDATES] received state update from {list(data.keys())} "
                                     f"- execution_id: {self.execution_id}")

                result = graph.get_state(self.checkpoint_config).values
                while self.activate_end:
                    message = ''
                    end_info = self.end_outputs[self.activate_end]
                    content = end_info.outputs[end_info.cursor]
                    if not content.is_variable:
                        message += content.literal
                    else:
                        node_outputs = result.get("runtime_vars", {})
                        variables = result.get("variables", {})
                        try:
                            chunk = evaluate_expression(
                                content.literal,
                                variables=variables,
                                node_outputs=node_outputs
                            )
                            chunk = self._trans_output_string(chunk)
                            message += chunk
                            full_content += chunk
                        except ValueError:
                            pass
                    end_info.cursor += 1
                    if end_info.cursor == len(end_info.outputs):
                        self.end_outputs.pop(self.activate_end)
                        self.activate_end = None
                        if self.end_outputs:
                            self.activate_end = list(self.end_outputs.keys())[0]
                    if message:
                        yield {
                            "event": "message",
                            "data": {
                                "chunk": message
                            }
                        }

                # Compute elapsed time
                end_time = datetime.datetime.now()
                elapsed_time = (end_time - start_time).total_seconds()
                result = graph.get_state(self.checkpoint_config).values
                logger.info(result)
                result["messages"].extend(
                    [
                        {
                            "role": "user",
                            "content": input_data.get("message", '')
                        },
                        {
                            "role": "assistant",
                            "content": full_content
                        }
                    ]
                )
                logger.info(
                    f"Workflow execution completed (streaming), "
                    f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
@@ -392,7 +526,7 @@ class WorkflowExecutor:
                # Emit the workflow_end event
                yield {
                    "event": "workflow_end",
                    "data": self._build_final_output(result, elapsed_time)
                    "data": self._build_final_output(result, elapsed_time, full_content)
                }

        except Exception as e:
@@ -414,31 +548,6 @@ class WorkflowExecutor:
            }
        }

    @staticmethod
    def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
        """Extract the final output from node outputs.

        Priority:
        1. The output of the last executed non-start/end node
        2. If there are no node outputs, return None

        Args:
            node_outputs: outputs of all nodes

        Returns:
            The final output string, or None
        """
        if not node_outputs:
            return None

        # Take the output of the last node
        last_node_output = list(node_outputs.values())[-1] if node_outputs else None

        if last_node_output and isinstance(last_node_output, dict):
            return last_node_output.get("output")

        return None
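    # Editor's sketch (not part of the commit): given
    #   node_outputs = {"start": {...}, "llm1": {"output": "hi"}}
    # the last entry wins, so _extract_final_output(node_outputs) -> "hi".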
    @staticmethod
    def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
        """Aggregate token usage across all nodes
@@ -529,178 +638,3 @@ async def execute_workflow_stream(
    )
    async for event in executor.execute_stream(input_data):
        yield event
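# Editor's sketch (hypothetical driver, not part of this commit): one way to consume the
# stream above. Event shapes follow the yields in execute_stream(): incremental text
# arrives as {"event": "message", "data": {"chunk": ...}} and the run closes with a
# {"event": "workflow_end", "data": {...}} payload. The call signature below is assumed.
#
# import asyncio
#
# async def main():
#     async for event in execute_workflow_stream(workflow_config, input_data={"message": "hi"}):
#         if event["event"] == "message":
#             print(event["data"]["chunk"], end="", flush=True)
#         elif event["event"] == "workflow_end":
#             print("\n--", event["data"])
#
# asyncio.run(main())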

# ==================== Tool management system integration ====================

# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
#     """List the tools available to a workflow
#
#     Args:
#         workspace_id: workspace ID
#         user_id: user ID
#
#     Returns:
#         List of available tools
#     """
#     if not TOOL_MANAGEMENT_AVAILABLE:
#         logger.warning("Tool management system is unavailable")
#         return []
#
#     try:
#         db = next(get_db())
#
#         # Create the tool registry
#         registry = ToolRegistry(db)
#
#         # Register built-in tool classes
#         from app.core.tools.builtin import (
#             DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
#         )
#         registry.register_tool_class(DateTimeTool)
#         registry.register_tool_class(JsonTool)
#         registry.register_tool_class(BaiduSearchTool)
#         registry.register_tool_class(MinerUTool)
#         registry.register_tool_class(TextInTool)
#
#         # Fetch active tools
#         import uuid
#         tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
#         active_tools = [tool for tool in tools if tool.status.value == "active"]
#
#         # Convert to Langchain tools
#         langchain_tools = []
#         for tool_info in active_tools:
#             try:
#                 tool_instance = registry.get_tool(tool_info.id)
#                 if tool_instance:
#                     langchain_tool = LangchainAdapter.convert_tool(tool_instance)
#                     langchain_tools.append(langchain_tool)
#             except Exception as e:
#                 logger.error(f"Failed to convert tool: {tool_info.name}, error: {e}")
#
#         logger.info(f"Fetched {len(langchain_tools)} tools for the workflow")
#         return langchain_tools
#
#     except Exception as e:
#         logger.error(f"Failed to fetch workflow tools: {e}")
#         return []
#
#
# class ToolWorkflowNode:
#     """Tool workflow node - executes a tool inside the workflow"""
#
#     def __init__(self, node_config: dict, workflow_config: dict):
#         """Initialize the tool node
#
#         Args:
#             node_config: node configuration
#             workflow_config: workflow configuration
#         """
#         self.node_config = node_config
#         self.workflow_config = workflow_config
#         self.tool_id = node_config.get("tool_id")
#         self.tool_parameters = node_config.get("parameters", {})
#
#     async def run(self, state: WorkflowState) -> WorkflowState:
#         """Execute the tool node"""
#         if not TOOL_MANAGEMENT_AVAILABLE:
#             logger.error("Tool management system is unavailable")
#             state["error"] = "Tool management system is unavailable"
#             return state
#
#         try:
#             from sqlalchemy.orm import Session
#             db = next(get_db())
#
#             # Create the tool executor
#             registry = ToolRegistry(db)
#             executor = ToolExecutor(db, registry)
#
#             # Prepare parameters (supports variable substitution)
#             parameters = self._prepare_parameters(state)
#
#             # Execute the tool
#             result = await executor.execute_tool(
#                 tool_id=self.tool_id,
#                 parameters=parameters,
#                 user_id=uuid.UUID(state["user_id"]),
#                 workspace_id=uuid.UUID(state["workspace_id"])
#             )
#
#             # Update state
#             node_id = self.node_config.get("id")
#             if result.success:
#                 state["node_outputs"][node_id] = {
#                     "type": "tool",
#                     "tool_id": self.tool_id,
#                     "output": result.data,
#                     "execution_time": result.execution_time,
#                     "token_usage": result.token_usage
#                 }
#
#                 # Update runtime variables
#                 if isinstance(result.data, dict):
#                     for key, value in result.data.items():
#                         state["runtime_vars"][f"{node_id}.{key}"] = value
#                 else:
#                     state["runtime_vars"][f"{node_id}.result"] = result.data
#             else:
#                 state["error"] = result.error
#                 state["error_node"] = node_id
#                 state["node_outputs"][node_id] = {
#                     "type": "tool",
#                     "tool_id": self.tool_id,
#                     "error": result.error,
#                     "execution_time": result.execution_time
#                 }
#
#             return state
#
#         except Exception as e:
#             logger.error(f"Tool node execution failed: {e}")
#             state["error"] = str(e)
#             state["error_node"] = self.node_config.get("id")
#             return state
#
#     def _prepare_parameters(self, state: WorkflowState) -> dict:
#         """Prepare tool parameters (supports variable substitution)"""
#         parameters = {}
#
#         for key, value in self.tool_parameters.items():
#             if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
#                 # Variable substitution
#                 var_path = value[2:-1]
#
#                 # Supports multi-level variable access, e.g. ${sys.message} or ${node1.result}
#                 if "." in var_path:
#                     parts = var_path.split(".")
#                     current = state.get("variables", {})
#
#                     for part in parts:
#                         if isinstance(current, dict) and part in current:
#                             current = current[part]
#                         else:
#                             # Fall back to runtime variables
#                             runtime_key = ".".join(parts)
#                             current = state.get("runtime_vars", {}).get(runtime_key, value)
#                             break
#
#                     parameters[key] = current
#                 else:
#                     # Simple variable
#                     variables = state.get("variables", {})
#                     parameters[key] = variables.get(var_path, value)
#             else:
#                 parameters[key] = value
#
#         return parameters
#
#
# # Register the tool node with NodeFactory (if present)
# try:
#     from app.core.workflow.nodes import NodeFactory
#     if hasattr(NodeFactory, 'register_node_type'):
#         NodeFactory.register_node_type("tool", ToolWorkflowNode)
#     logger.info("Tool node registered with the workflow system")
# except Exception as e:
#     logger.warning(f"Failed to register tool node: {e}")
@@ -1,12 +1,15 @@
import logging
import re
import uuid
from collections import defaultdict
from functools import lru_cache
from typing import Any

from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field

from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
@@ -15,6 +18,153 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
logger = logging.getLogger(__name__)


class OutputContent(BaseModel):
    """
    Represents a single output segment of an End node.

    An output segment can be either:
    - literal text (a static string)
    - a variable placeholder (e.g. {{ node.field }})

    Each segment has its own activation state, which is especially
    important in stream mode.
    """

    literal: str = Field(
        ...,
        description="Raw output content. Can be literal text or a variable placeholder."
    )

    activate: bool = Field(
        ...,
        description=(
            "Whether this output segment is currently active.\n"
            "- True: allowed to be emitted/output\n"
            "- False: blocked until activated by branch control"
        )
    )

    is_variable: bool = Field(
        ...,
        description=(
            "Whether this segment represents a variable placeholder.\n"
            "True -> variable (e.g. {{ node.field }})\n"
            "False -> literal text"
        )
    )

    def depends_on_node(self, node_id: str) -> bool:
        """
        Check whether this output segment depends on a specific node's variable.

        This method examines the segment's `literal` to see if it contains a
        variable placeholder referencing the given node in the form:

            {{ node_id.field_name }}

        It uses a regular expression that matches the exact node ID, avoiding
        false positives from substring matches (e.g., 'node1' should not match 'node10').

        Args:
            node_id (str): The ID of the node to check for in this segment's variable placeholders.

        Returns:
            bool:
            - True if the segment contains a variable referencing the given node.
            - False otherwise.

        Example:
            literal = "{{node1.name}}"

            depends_on_node("node1") -> True
            depends_on_node("node2") -> False

        Usage:
            This method is primarily used in stream mode to determine whether
            a particular variable output segment should be activated when a
            specific upstream node completes execution.
        """
        variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
        pattern = re.compile(variable_pattern)
        match = pattern.search(self.literal)
        if match:
            return True
        return False
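# Editor's sketch (not part of the commit): exercising depends_on_node(). The regex
# anchors on the exact node ID, so 'node1' does not false-positive against 'node10'.
def _demo_depends_on_node():
    seg = OutputContent(literal="{{ node1.output }}", activate=False, is_variable=True)
    assert seg.depends_on_node("node1")
    assert not seg.depends_on_node("node10")
    assert not seg.depends_on_node("node")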
class StreamOutputConfig(BaseModel):
    """
    Streaming output configuration for an End node.

    This structure controls:
    - whether the End node output is globally active
    - which upstream branch nodes are responsible for activation
    - how each output segment behaves in streaming mode
    """

    activate: bool = Field(
        ...,
        description=(
            "Global activation state of the End node output.\n"
            "If False, no output should be emitted until all control nodes are resolved."
        )
    )

    control_nodes: list[str] = Field(
        ...,
        description=(
            "List of upstream branch node IDs that control this End node.\n"
            "Each node must signal completion before output becomes active."
        )
    )

    outputs: list[OutputContent] = Field(
        ...,
        description="Ordered list of output segments parsed from the output template."
    )

    cursor: int = Field(
        ...,
        description=(
            "Streaming cursor index.\n"
            "Indicates how many output segments have already been emitted."
        )
    )

    def update_activate(self, node_id):
        """
        Update the activation state when an upstream node completes.

        This method is typically called when a branch/control node finishes execution.

        Behavior:
        1. If the node is a control node:
           - Remove it from `control_nodes`
           - If all control nodes are resolved, activate the entire output

        2. Activate variable output segments that depend on this node:
           - If an output segment is a variable
           - And its literal references the completed node_id
           - Mark that segment as active
        """

        # Case 1: resolve control branch dependency
        if node_id in self.control_nodes:
            self.control_nodes.remove(node_id)

            # All branch constraints resolved → enable output
            if not self.control_nodes:
                self.activate = True

        # Case 2: activate variable segments related to this node
        for i in range(len(self.outputs)):
            if (
                self.outputs[i].is_variable
                and self.outputs[i].depends_on_node(node_id)
            ):
                self.outputs[i].activate = True
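# Editor's sketch (not part of the commit): the activation lifecycle of StreamOutputConfig.
# An End node gated by one branch node starts inactive; once that control node completes,
# update_activate() flips the global flag and wakes variable segments that reference the
# finished node.
def _demo_stream_output_config():
    cfg = StreamOutputConfig(
        activate=False,
        control_nodes=["branch1"],
        outputs=[
            OutputContent(literal="Result: ", activate=True, is_variable=False),
            OutputContent(literal="{{ llm1.output }}", activate=False, is_variable=True),
        ],
        cursor=0,
    )
    cfg.update_activate("branch1")   # last control node resolved -> output enabled
    assert cfg.activate
    cfg.update_activate("llm1")      # variable segment depending on llm1 becomes active
    assert cfg.outputs[1].activate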
class GraphBuilder:
    def __init__(
        self,
@@ -29,6 +179,12 @@ class GraphBuilder:

        self.start_node_id = None
        self.end_node_ids = []
        self.node_map = {node["id"]: node for node in self.nodes}
        self.end_node_map: dict[str, StreamOutputConfig] = {}
        self._find_upstream_branch_node = lru_cache(
            maxsize=len(self.nodes) * 2
        )(self._find_upstream_branch_node)
        self._analyze_end_node_output()

        self.graph = StateGraph(WorkflowState)
        self.add_nodes()
@@ -43,79 +199,182 @@ class GraphBuilder:
    def edges(self) -> list[dict[str, Any]]:
        return self.workflow_config.get("edges", [])

    def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
        """
        Analyze the prefix configuration for End nodes.
    def get_node_type(self, node_id: str) -> str:
        """Retrieve the type of a node given its ID.

        This function scans each End node's output template, identifies
        references to its direct upstream nodes, and extracts the prefix
        string appearing before the first reference.
        Args:
            node_id (str): The unique identifier of the node.

        Returns:
            tuple:
            - dict[str, str]: Mapping from upstream node ID to its End node prefix
            - set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
            str: The type of the node.

        Raises:
            RuntimeError: If no node with the given `node_id` exists.
        """
        import re
        try:
            return self.node_map[node_id]["type"]
        except KeyError:
            raise RuntimeError(f"Node not found: Id={node_id}")

        prefixes = {}
        adjacent_and_referenced = set()  # Record nodes directly adjacent to End and referenced
    def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
        """Find upstream branch nodes for a given target node in the workflow graph.

        # Find all End nodes
        This method identifies all upstream control (branch) nodes that can affect
        the execution of `target_node`. If `target_node` is reachable from a start
        node (i.e., a node with no upstream nodes), the method returns an empty tuple.

        The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
        and non-branch nodes, recursively traversing upstream through non-branch
        nodes. If any non-branch upstream path does not lead to a branch node,
        the result will indicate that no valid upstream branch node exists.

        Args:
            target_node (str): The identifier of the target node.

        Returns:
            tuple[bool, tuple[str]]:
            - has_branch (bool): True if all upstream non-branch paths lead to at least
              one branch node; False if any path reaches a start node without a branch.
            - branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
              affecting `target_node`. Returns an empty tuple if `has_branch` is False.
        """
        source_nodes = [
            edge.get("source")
            for edge in self.edges
            if edge.get("target") == target_node
        ]
        if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
            return False, tuple()

        branch_nodes = []
        non_branch_nodes = []

        for node_id in source_nodes:
            if self.get_node_type(node_id) in BRANCH_NODES:
                branch_nodes.append(node_id)
            else:
                non_branch_nodes.append(node_id)

        has_branch = True
        for node_id in non_branch_nodes:
            node_has_branch, nodes = self._find_upstream_branch_node(node_id)
            has_branch = has_branch and node_has_branch
            if not has_branch:
                break
            branch_nodes.extend(nodes)
        if not has_branch:
            branch_nodes = []

        return has_branch, tuple(set(branch_nodes))
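    # Editor's walkthrough of the traversal above, on a hypothetical graph
    # start -> if_else -> llm -> end (if_else being a BRANCH_NODES type):
    #   _find_upstream_branch_node("end")   -> (True, ("if_else",))   via the llm hop
    #   _find_upstream_branch_node("llm")   -> (True, ("if_else",))
    #   _find_upstream_branch_node("start") -> (False, ())            reaches a start node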
    def _analyze_end_node_output(self):
        """
        Analyze the output templates of all End nodes and generate StreamOutputConfig.

        This method is responsible for parsing the `output` field of End nodes,
        splitting literal text and variable placeholders (e.g. {{ node.field }}),
        and determining whether each output segment should be activated immediately
        or controlled by upstream branch nodes.

        In stream mode:
        - If the End node is controlled by any upstream branch node, the output
          will be initially inactive and controlled by those branch nodes.
        - Otherwise, the output is activated immediately.

        In non-stream mode:
        - All outputs are activated by default.
        """

        # Collect all End nodes in the workflow
        end_nodes = [node for node in self.nodes if node.get("type") == "end"]
        logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")

        # Iterate through each End node to analyze its output
        for end_node in end_nodes:
            end_node_id = end_node.get("id")
            output_template = end_node.get("config", {}).get("output")
            config = end_node.get("config", {})
            output = config.get("output")

            logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")

            if not output_template:
            # Skip End nodes without output configuration
            if not output:
                continue

            # Find all node references in the template
            # Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
            pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
            matches = list(re.finditer(pattern, output_template))
            # Regex to split output into:
            # - variable placeholders: {{ ... }}
            # - normal literal text
            #
            # Example:
            #   "Hello {{user.name}}!" ->
            #   ["Hello ", "{{user.name}}", "!"]
            pattern = r'\{\{.*?\}\}|[^{}]+'

            logger.info(f"[Prefix Analysis] Found {len(matches)} node references in the template")
            # Strict variable format: {{ node_id.field_name }}
            variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
            variable_pattern = re.compile(variable_pattern_string)

            # Identify all direct upstream nodes connected to the End node
            direct_upstream_nodes = []
            for edge in self.edges:
                if edge.get("target") == end_node_id:
                    source_node_id = edge.get("source")
                    direct_upstream_nodes.append(source_node_id)
            # Split output into ordered segments
            output_template = list(re.findall(pattern, output))

            logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
            # Determine whether each segment is literal text
            # True -> literal (can be directly output)
            # False -> variable placeholder (needs runtime value)
            output_flag = [
                not bool(variable_pattern.match(item))
                for item in output_template
            ]

            # Find the first reference to a direct upstream node
            for match in matches:
                referenced_node_id = match.group(1)
                logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
            # Stream mode: output activation depends on upstream branch nodes
            if self.stream:
                # Find upstream branch nodes that can control this End node
                has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)

                if referenced_node_id in direct_upstream_nodes:
                    # This is a reference to a direct upstream node; extract the prefix
                    prefix = output_template[:match.start()]
                # Build StreamOutputConfig for this End node
                self.end_node_map[end_node_id] = StreamOutputConfig(
                    # If there is no upstream branch, output is active immediately
                    activate=not has_branch,

                    logger.info(f"[Prefix Analysis] "
                                f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
                    # Branch nodes that control activation of this End node
                    control_nodes=list(control_nodes),

                    # Mark this node as "adjacent and referenced"
                    adjacent_and_referenced.add(referenced_node_id)
                    # Convert output segments into OutputContent objects
                    outputs=list(
                        [
                            OutputContent(
                                literal=output_string,
                                # Literal text can be activated immediately unless blocked by branch
                                activate=activate,
                                # Variable segments are marked explicitly
                                is_variable=not activate
                            )
                            for output_string, activate in zip(output_template, output_flag)
                        ]
                    ),
                    # Cursor for streaming output (initially 0)
                    cursor=0
                )
                logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
                            f"activate: {not has_branch}, "
                            f"control_nodes: {control_nodes},"
                            f"output: {output_template},"
                            f"output_activate: {output_flag}")

                    if prefix:
                        prefixes[referenced_node_id] = prefix
                        logger.info(f"[Prefix Analysis] "
                                    f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")

                    # Only handle the first reference to a direct upstream node
                    break

        logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
        logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
        return prefixes, adjacent_and_referenced
            # Non-stream mode: all outputs are activated by default
            else:
                self.end_node_map[end_node_id] = StreamOutputConfig(
                    activate=True,
                    control_nodes=[],
                    outputs=list(
                        [
                            OutputContent(
                                literal=output_string,
                                activate=True,
                                is_variable=not activate
                            )
                            for output_string, activate in zip(output_template, output_flag)
                        ]
                    ),
                    cursor=0
                )
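    # Editor's sketch of the two regexes above on a concrete template:
    #   re.findall(r'\{\{.*?\}\}|[^{}]+', "Hello {{user.name}}!")
    #       -> ["Hello ", "{{user.name}}", "!"]
    #   output_flag (True = literal, False = variable placeholder)
    #       -> [True, False, True]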
    def add_nodes(self):
        """Add all nodes from the workflow configuration to the state graph.
@@ -135,9 +394,6 @@ class GraphBuilder:
        Returns:
            None
        """
        # Analyze End node prefixes if in stream mode
        end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())

        for node in self.nodes:
            node_type = node.get("type")
            node_id = node.get("id")
@@ -171,17 +427,6 @@ class GraphBuilder:
                        related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"

            if node_instance:
                # Inject End node prefix configuration if in stream mode
                if self.stream and node_id in end_prefixes:
                    node_instance._end_node_prefix = end_prefixes[node_id]
                    logger.info(f"Injected End prefix for node {node_id}")

                # Mark nodes adjacent and referenced to the End node in stream mode
                if self.stream:
                    node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
                    if node_id in adjacent_and_referenced:
                        logger.info(f"Node {node_id} marked as adjacent and referenced to End node")

                # Wrap the node's run method to avoid closure issues
                if self.stream:
                    # Stream mode: create an async generator function
@@ -261,6 +506,7 @@ class GraphBuilder:
        for source_node, branches in conditional_edges.items():
            def make_router(src, branch_list):
                """Create a router function for each source node that routes to a NOP node for later merging."""

        def make_branch_node(node_name, targets):
            def node(s):
                # NOTE: a NOP node must not modify state
@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
    error: str | None
    error_node: str | None

    # Streaming buffer (stores real-time streaming output of nodes)
    # Format: {node_id: {"chunks": [...], "full_content": "..."}}
    streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]

    # Node activation status
    activate: Annotated[dict[str, bool], merge_activate_state]
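# Editor's sketch (not part of the commit): the reducer on streaming_buffer above is a
# plain dict merge, so per-node updates accumulate and later writes win per key.
merge = lambda x, y: {**x, **y}
a = {"llm1": {"full_content": "Hel", "chunk_count": 1}}
b = {"llm1": {"full_content": "Hello", "chunk_count": 2}, "end": {"full_content": ""}}
assert merge(a, b)["llm1"]["chunk_count"] == 2 and "end" in merge(a, b)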
@@ -300,7 +296,7 @@ class BaseNode(ABC):
        """
        if not self.check_activate(state):
            yield self.trans_activate(state)
            logger.info(f"跳过节点{self.node_id}")
            logger.info(f"jump node: {self.node_id}")
            return

        import time
@@ -313,19 +309,6 @@ class BaseNode(ABC):
        # Get LangGraph's stream writer for sending custom data
        writer = get_stream_writer()

        # Check if this is an End node
        # End nodes CAN send chunks (for suffix), but only after LLM content
        is_end_node = self.node_type == "end"

        # Check if this node is adjacent to an End node (for message type)
        is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)

        # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
        chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"

        logger.debug(
            f"Node {self.node_id} chunk type: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")

        # Accumulate the complete result (for final wrapping)
        chunks = []
        final_result = None
@@ -340,66 +323,25 @@ class BaseNode(ABC):
                raise TimeoutError()

            # Check if it's a completion marker
            if isinstance(item, dict) and item.get("__final__"):
            if item.get("__final__"):
                final_result = item["result"]
            elif isinstance(item, str):
                # String is a chunk
            else:
                chunk_count += 1
                chunks.append(item)
                full_content = "".join(chunks)
                content = str(item.get("chunk"))
                done = item.get("done", False)
                chunks.append(content)

                # Send chunks for all nodes (including End nodes for suffix)
                logger.debug(f"Node {self.node_id} sending chunk #{chunk_count}: {item[:50]}...")
                logger.debug(f"Node {self.node_id} sending chunk #{chunk_count}: {content[:50]}...")

                # 1. Send via stream writer (for real-time client updates)
                writer({
                    "type": chunk_type,  # "message" or "node_chunk"
                    "type": "node_chunk",
                    "node_id": self.node_id,
                    "chunk": item,
                    "full_content": full_content,
                    "chunk_index": chunk_count
                    "chunk": content,
                    "done": done
                })

                # 2. Update the streaming buffer in state (for downstream nodes)
                # Only non-End nodes need the streaming buffer
                if not is_end_node:
                    yield {
                        "streaming_buffer": {
                            self.node_id: {
                                "full_content": full_content,
                                "chunk_count": chunk_count,
                                "is_complete": False
                            }
                        }
                    }
            else:
                # Other types are also treated as chunks
                chunk_count += 1
                chunk_str = str(item)
                chunks.append(chunk_str)
                full_content = "".join(chunks)

                # Send chunks for all nodes
                writer({
                    "type": chunk_type,  # "message" or "node_chunk"
                    "node_id": self.node_id,
                    "chunk": chunk_str,
                    "full_content": full_content,
                    "chunk_index": chunk_count
                })

                # Only non-End nodes need the streaming buffer
                if not is_end_node:
                    yield {
                        "streaming_buffer": {
                            self.node_id: {
                                "full_content": full_content,
                                "chunk_count": chunk_count,
                                "is_complete": False
                            }
                        }
                    }

        elapsed_time = time.time() - start_time

        logger.info(f"Node {self.node_id} streaming execution finished, elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
@@ -426,16 +368,6 @@ class BaseNode(ABC):
            "looping": state["looping"]
        }

        # Add the streaming buffer for non-End nodes
        if not is_end_node:
            state_update["streaming_buffer"] = {
                self.node_id: {
                    "full_content": "".join(chunks),
                    "chunk_count": chunk_count,
                    "is_complete": True  # Mark as complete
                }
            }

        # Finally yield the state update
        # LangGraph will merge this into state
        yield state_update | self.trans_activate(state)
@@ -1,3 +1,3 @@
from app.core.workflow.nodes.code.node import CodeNode

__all__ = ["CodeNode"]
__all__ = ["CodeNode"]
@@ -7,7 +7,6 @@ from textwrap import dedent
from typing import Any

import httpx
from sympy.physics.vector import vlatex

from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType
@@ -1,5 +1,4 @@
import asyncio
import copy
import logging
import re
from typing import Any
@@ -6,7 +6,6 @@ from langgraph.graph.state import CompiledStateGraph

from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType

@@ -5,10 +5,8 @@ End node implementation
"""

import logging
import re

from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import NodeType

logger = logging.getLogger(__name__)
@@ -37,24 +35,8 @@ class EndNode(BaseNode):
        # If an output template is configured, render it; otherwise use the default output
        if output_template:
            output = self._render_template(output_template, state, strict=False)
            state['messages'].extend([
                {
                    "role": "user",
                    "content": self.get_variable("sys.message", state)
                },
                {
                    "role": "assistant",
                    "content": output
                }
            ])
        else:
            state['messages'].extend([
                {
                    "role": "user",
                    "content": self.get_variable("sys.message", state)
                },
            ])
            output = "Workflow completed"
            output = ""

        # Statistics (for logging)
        node_outputs = state.get("node_outputs", {})
@@ -63,274 +45,3 @@ class EndNode(BaseNode):

        logger.info(f"Node {self.node_id} (End) finished; {total_nodes} nodes executed in total")

        return output

    def _extract_referenced_nodes(self, template: str) -> list[str]:
        """Extract the node IDs referenced in a template.

        Example: 'Result: {{llm_qa.output}}' -> ['llm_qa']

        Args:
            template: the template string

        Returns:
            List of referenced node IDs
        """
        # Match the {{node_id.xxx}} format
        pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
        matches = re.findall(pattern, template)
        return list(set(matches))  # deduplicate
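    # Editor's sketch of the extraction above:
    #   _extract_referenced_nodes("A {{llm_qa.output}} B {{start.test}} {{llm_qa.output}}")
    #       -> ['llm_qa', 'start']  (set() deduplicates; order is not guaranteed)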
    def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
        """Parse a template, separating static text from dynamic references.

        Example: 'Hello {{llm.output}}, this is a suffix'
        returns: [
            {"type": "static", "content": "Hello "},
            {"type": "dynamic", "node_id": "llm", "field": "output"},
            {"type": "static", "content": ", this is a suffix"}
        ]

        Args:
            template: the template string
            state: the workflow state

        Returns:
            List of template parts
        """
        import re

        parts = []
        last_end = 0

        # Match {{xxx}} or {{ xxx }} (spaces allowed)
        pattern = r'\{\{\s*([^}]+?)\s*\}\}'

        for match in re.finditer(pattern, template):
            start, end = match.span()

            # Append the preceding static text
            if start > last_end:
                static_text = template[last_end:start]
                if static_text:
                    parts.append({"type": "static", "content": static_text})

            # Parse the dynamic reference
            ref = match.group(1).strip()

            # Check whether it is a node reference (e.g. llm.output or llm_qa.output)
            if '.' in ref:
                node_id, field = ref.split('.', 1)
                parts.append({
                    "type": "dynamic",
                    "node_id": node_id,
                    "field": field,
                    "raw": ref
                })
            else:
                # Other references (e.g. {{var.xxx}}) are treated as static:
                # render this part directly
                rendered = self._render_template(f"{{{{{ref}}}}}", state)
                parts.append({"type": "static", "content": rendered})

            last_end = end

        # Append the trailing static text
        if last_end < len(template):
            static_text = template[last_end:]
            if static_text:
                parts.append({"type": "static", "content": static_text})

        return parts
async def execute_stream(self, state: WorkflowState):
|
||||
"""Execute End node business logic (streaming)
|
||||
|
||||
Smart output strategy:
|
||||
1. Check if template references a direct upstream LLM node
|
||||
2. If yes, only output the part AFTER that reference (suffix)
|
||||
3. Prefix and LLM content have already been sent during LLM node streaming
|
||||
|
||||
Note: Only LLM nodes get this special treatment. Other node types output normally.
|
||||
|
||||
Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- Direct upstream LLM node is llm_qa
|
||||
- Prefix '{{start.test}}hahaha ' was sent before LLM node streaming
|
||||
- LLM content was streamed during LLM node execution
|
||||
- End node only outputs ' lalalalala a' (suffix, sent as one chunk)
|
||||
|
||||
Args:
|
||||
state: Workflow state
|
||||
|
||||
Yields:
|
||||
Completion marker
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "message", # End node output uses message type
|
||||
"node_id": self.node_id,
|
||||
"chunk": "",
|
||||
"full_content": output,
|
||||
"chunk_index": 1,
|
||||
"is_suffix": False
|
||||
})
|
||||
state['messages'].extend([
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.get_variable("sys.message", state)
|
||||
}
|
||||
])
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
# Find direct upstream LLM nodes
|
||||
direct_upstream_llm_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
# Check if the source node is an LLM node
|
||||
for node in self.workflow_config.get("nodes", []):
|
||||
logger.info(f"节点 {self.node_id} 的类型 {node.get("type")}")
|
||||
if node.get("id") == source_node_id and node.get("type") == NodeType.LLM:
|
||||
direct_upstream_llm_nodes.append(source_node_id)
|
||||
break
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游 LLM 节点: {direct_upstream_llm_nodes}")
|
||||
|
||||
# Parse template parts
|
||||
parts = self._parse_template_parts(output_template, state)
logger.info(f"Node {self.node_id} parsed the template into {len(parts)} parts")
for i, part in enumerate(parts):
    logger.info(f"[template parsing] part[{i}]: {part}")

# Find the first reference to a direct upstream LLM node
upstream_llm_ref_index = None
for i, part in enumerate(parts):
    if part["type"] == "dynamic" and part["node_id"] in direct_upstream_llm_nodes:
        upstream_llm_ref_index = i
        logger.info(f"Node {self.node_id} found a reference to direct upstream LLM node {part['node_id']} at index {i}")
        break

if upstream_llm_ref_index is None:
    # No reference to a direct upstream LLM node, output the complete template content
    output = self._render_template(output_template, state, strict=False)
    logger.info(f"Node {self.node_id} references no direct upstream LLM node; outputting the full content: '{output[:50]}...'")

    # Send the complete content via the writer (as a single message chunk)
    from langgraph.config import get_stream_writer
    writer = get_stream_writer()
    writer({
        "type": "message",  # End node output uses the message type
        "node_id": self.node_id,
        "chunk": output,
        "full_content": output,
        "chunk_index": 1,
        "is_suffix": False
    })
    logger.info(f"Node {self.node_id} sent the full content via the writer")

    state['messages'].extend([
        {
            "role": "user",
            "content": self.get_variable("sys.message", state)
        },
        {
            "role": "assistant",
            "content": output
        }
    ])

    # yield the completion marker
    yield {"__final__": True, "result": output}
    return

# There is a reference to a direct upstream LLM node; only output the part after that reference (the suffix)
logger.info(
    f"Node {self.node_id} detected a direct upstream LLM reference; outputting only the suffix (starting at index {upstream_llm_ref_index + 1})")

# Collect the suffix parts
suffix_parts = []
logger.info(f"[suffix debug] collecting the suffix from index {upstream_llm_ref_index + 1} to {len(parts) - 1}")
for i in range(upstream_llm_ref_index + 1, len(parts)):
    part = parts[i]
    logger.info(f"[suffix debug] processing part[{i}]: {part}")
    if part["type"] == "static":
        # Static text
        logger.info(f"[suffix debug] appending static text: '{part['content']}'")
        suffix_parts.append(part["content"])

    elif part["type"] == "dynamic":
        # Other dynamic references (when the template contains several)
        node_id = part["node_id"]
        field = part["field"]

        # Use the VariablePool to resolve the variable value
        pool = self.get_variable_pool(state)
        try:
            # Try to fetch the variable value, defaulting to an empty string
            content = pool.get([node_id, field], default="")
            logger.info(f"[suffix debug] resolved variable {node_id}.{field}: '{content}'")
        except Exception as e:
            logger.warning(f"[suffix debug] failed to resolve variable {node_id}.{field}: {e}")
            content = ""

        # Convert to string if not None
        suffix_parts.append(str(content) if content is not None else "")

# Join the suffix
suffix = "".join(suffix_parts)

# Build the full output for the return value (prefix + dynamic content + suffix)
full_output = self._render_template(output_template, state, strict=False)

state['messages'].extend([
    {
        "role": "user",
        "content": self.get_variable("sys.message", state)
    },
    {
        "role": "assistant",
        "content": full_output
    }
])

logger.info(f"[suffix debug] node {self.node_id} suffix part count: {len(suffix_parts)}")
logger.info(f"[suffix debug] suffix content: '{suffix}'")
logger.info(f"[suffix debug] suffix length: {len(suffix)}")
logger.info(f"[suffix debug] suffix empty: {not suffix}")

if suffix:
    logger.info(f"Node {self.node_id} outputting the suffix: '{suffix}...' (length: {len(suffix)})")
    # Emit the suffix in one go (as a single chunk).
    # Note: do not yield the raw string here, because base_node would process it
    # character by character; send it through the writer instead.
    from langgraph.config import get_stream_writer
    writer = get_stream_writer()
    writer({
        "type": "message",  # End node output uses the message type
        "node_id": self.node_id,
        "chunk": suffix,
        "full_content": full_output,  # full_content is the complete rendered result (prefix + LLM + suffix)
        "chunk_index": 1,
        "is_suffix": True
    })
    logger.info(f"Node {self.node_id} sent the suffix via the writer, full_content length: {len(full_output)}")
else:
    logger.warning(f"[suffix debug] node {self.node_id} suffix is empty, nothing sent! "
                   f"upstream_llm_ref_index={upstream_llm_ref_index}, part count={len(parts)}")

# Statistics
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)

logger.info(f"Node {self.node_id} (End) finished executing (streaming); {total_nodes} nodes ran in total")

# yield the completion marker (including the full output)
yield {"__final__": True, "result": full_output}
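To make the branch above concrete, here is a minimal, self-contained sketch of the suffix split, assuming the simplified parts representation that appears in the logs (the real _parse_template_parts may carry additional fields):

# Hypothetical illustration of the suffix split performed by the End node above.
parts = [
    {"type": "static", "content": "Answer: "},                  # prefix, streamed early by the LLM node
    {"type": "dynamic", "node_id": "llm_1", "field": "text"},   # filled by the upstream LLM stream
    {"type": "static", "content": "\n-- done --"},              # suffix, sent once by the End node
]
direct_upstream_llm_nodes = {"llm_1"}

ref = next((i for i, p in enumerate(parts)
            if p["type"] == "dynamic" and p["node_id"] in direct_upstream_llm_nodes), None)
suffix = "".join(p["content"] for p in parts[ref + 1:] if p["type"] == "static")
assert suffix == "\n-- done --"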
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        self.typed_config: IfElseNodeConfig | None= None
        self.typed_config: IfElseNodeConfig | None = None

    @staticmethod
    def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
@@ -7,18 +7,18 @@ LLM node implementation
import logging
import re
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage

from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from langchain_core.messages import AIMessage

from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService

from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode

logger = logging.getLogger(__name__)
@@ -231,42 +231,14 @@ class LLMNode(BaseNode):
            Text chunks or a completion marker
        """
        self.typed_config = LLMNodeConfig(**self.config)
        from langgraph.config import get_stream_writer

        llm, prompt_or_messages = self._prepare_llm(state, True)

        logger.info(f"Node {self.node_id} starting LLM call (streaming)")
        logger.debug(f"LLM config: streaming={getattr(llm._model, 'streaming', 'unknown')}")

        # Check whether an End-node prefix has been injected
        writer = get_stream_writer()
        end_prefix = getattr(self, '_end_node_prefix', None)

        logger.info(f"[LLM prefix] node {self.node_id} prefix configured: {end_prefix is not None}")
        if end_prefix:
            logger.info(f"[LLM prefix] prefix content: '{end_prefix}'")

        if end_prefix:
            # Render the prefix (it may reference other variables)
            try:
                rendered_prefix = self._render_template(end_prefix, state)
                logger.info(f"Node {self.node_id} sending the End-node prefix early: '{rendered_prefix[:50]}...'")

                # Send the End node's prefix ahead of the LLM stream (using the "message" type)
                writer({
                    "type": "message",  # all End-related content uses the message type
                    "node_id": "end",  # mark this as End-node output
                    "chunk": rendered_prefix,
                    "full_content": rendered_prefix,
                    "chunk_index": 0,
                    "is_prefix": True  # mark this chunk as the prefix
                })
            except Exception as e:
                logger.warning(f"Failed to render/send the End-node prefix: {e}")

        # Accumulate the full response
        full_response = ""
        last_chunk = None
        chunk_count = 0

        # Call the LLM (streaming; accepts a string or a message list)
@@ -284,12 +256,19 @@ class LLMNode(BaseNode):
                # Only process non-empty content
                if content:
                    full_response += content
                    last_chunk = chunk
                    chunk_count += 1

                    # Stream each text chunk
                    yield content
                    yield {
                        "__final__": False,
                        "chunk": content
                    }

        yield {
            "__final__": False,
            "chunk": "",
            "done": True
        }
        logger.info(f"Node {self.node_id} LLM call finished, output length: {len(full_response)}, total chunks: {chunk_count}")

        # Build the full AIMessage (including metadata)
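The change above replaces raw-string yields with structured dicts, so a consumer can tell incremental chunks, the empty done sentinel, and the final marker apart. A minimal sketch of such a consumer, assuming the generator is exposed as an async execute_stream (the method name is an assumption, not part of this diff):

# Hypothetical consumer of the streaming protocol shown above.
async def collect_stream(node, state) -> str:
    full = ""
    async for item in node.execute_stream(state):
        if item.get("__final__"):
            return item["result"]        # completion marker carries the final result
        if item.get("done"):
            continue                     # empty sentinel emitted once the LLM stream ends
        full += item["chunk"]            # incremental text chunk
    return full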
@@ -1,8 +1,6 @@
import uuid
from uuid import UUID

from pydantic import Field
from typing import Literal

from app.core.workflow.nodes.base_config import BaseNodeConfig

@@ -12,7 +10,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
        ...
    )

    config_id: UUID = Field(
    config_id: UUID | int = Field(
        ...
    )


@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
        return await MemoryAgentService().read_memory(
            end_user_id=end_user_id,
            message=self._render_template(self.typed_config.message, state),
            config_id=str(self.typed_config.config_id),
            config_id=self.typed_config.config_id,
            search_switch=self.typed_config.search_switch,
            history=[],
            db=db,
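Widening config_id to UUID | int lets the node accept both the new UUID primary keys and legacy integer IDs (the config_id_old backup column later in this diff suggests an in-flight migration). A minimal sketch of the coercion behavior, assuming pydantic v2's default smart-union validation:

from uuid import UUID
from pydantic import BaseModel, Field

class DemoConfig(BaseModel):  # illustrative stand-in for MemoryReadNodeConfig
    config_id: UUID | int = Field(...)

print(type(DemoConfig(config_id=42).config_id))  # <class 'int'>: legacy ID kept as-is
print(type(DemoConfig(config_id="123e4567-e89b-12d3-a456-426614174000").config_id))  # <class 'uuid.UUID'>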
@@ -5,6 +5,7 @@ from pydantic import Field, BaseModel

from app.core.workflow.nodes.base_config import BaseNodeConfig


class ClassifierConfig(BaseModel):
    """Classifier node configuration"""

@@ -13,7 +14,7 @@ class ClassifierConfig(BaseModel):

class QuestionClassifierNodeConfig(BaseNodeConfig):
    """Question classifier node configuration"""

    model_id: uuid.UUID = Field(..., description="LLM model ID")
    input_variable: str = Field(default="{{sys.message}}", description="Input variable selector (the user question)")
    user_supplement_prompt: Optional[str] = Field(default=None, description="Supplementary user prompt for extra classification instructions")
@@ -18,30 +18,30 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"

class QuestionClassifierNode(BaseNode):
    """Question classifier node"""

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        self.typed_config: QuestionClassifierNodeConfig | None = None
        self.category_to_case_map = {}

    def _get_llm_instance(self) -> RedBearLLM:
        """Get an LLM instance"""
        with get_db_read() as db:
            config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)

        if not config:
            raise BusinessException("The configured model does not exist", BizCode.NOT_FOUND)

        if not config.api_keys or len(config.api_keys) == 0:
            raise BusinessException("The model configuration is missing an API Key", BizCode.INVALID_PARAMETER)

        api_config = config.api_keys[0]
        model_name = api_config.model_name
        provider = api_config.provider
        api_key = api_config.api_key
        base_url = api_config.api_base
        model_type = config.type

        return RedBearLLM(
            RedBearModelConfig(
                model_name=model_name,
@@ -64,7 +64,7 @@ class QuestionClassifierNode(BaseNode):
            case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
            category_map[category_name] = case_tag
        return category_map

    async def execute(self, state: WorkflowState) -> dict:
        """Run the question classification"""
        self.typed_config = QuestionClassifierNodeConfig(**self.config)
@@ -74,11 +74,12 @@ class QuestionClassifierNode(BaseNode):
        categories = self.typed_config.categories or []
        category_names = [class_item.class_name.strip() for class_item in categories]
        category_count = len(category_names)

        if not question:
            logger.warning(
                f"Node {self.node_id} received no input question; falling back to the default branch"
                f" (default branch: {DEFAULT_EMPTY_QUESTION_CASE}, total categories: {category_count})"
                f" (default branch: {DEFAULT_EMPTY_QUESTION_CASE}"
                f", total categories: {category_count})"
            )
            # If the category list is empty, return the default unknown branch; otherwise return CASE1
            if category_count > 0:
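For reference, the category-mapping loop above pairs each configured category name with a CASE tag. A sketch, assuming the index starts at 1 (consistent with DEFAULT_EMPTY_QUESTION_CASE being CASE1; the starting index is an assumption):

# Illustrative only; mirrors the category-to-case loop above.
DEFAULT_CASE_PREFIX = "CASE"
category_names = ["billing", "tech support", "other"]
category_map = {name: f"{DEFAULT_CASE_PREFIX}{idx}"
                for idx, name in enumerate(category_names, start=1)}
# {'billing': 'CASE1', 'tech support': 'CASE2', 'other': 'CASE3'}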
@@ -1,4 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode

__all__ = ["ToolNode", "ToolNodeConfig"]
__all__ = ["ToolNode", "ToolNodeConfig"]
@@ -16,11 +16,11 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")

class ToolNode(BaseNode):
    """Tool node"""

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        self.typed_config: ToolNodeConfig | None = None

    async def execute(self, state: WorkflowState) -> dict[str, Any]:
        """Execute the tool"""
        self.typed_config = ToolNodeConfig(**self.config)
@@ -28,21 +28,21 @@ class ToolNode(BaseNode):
        tenant_id = self.get_variable("sys.tenant_id", state)
        user_id = self.get_variable("sys.user_id", state)
        workspace_id = self.get_variable("sys.workspace_id", state)

        # If there is no tenant ID, try to resolve it from the workspace ID
        if not tenant_id:
            if workspace_id:
                from app.repositories.tool_repository import ToolRepository
                with get_db_read() as db:
                    tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)

        if not tenant_id:
            logger.error(f"Node {self.node_id} is missing a tenant ID")
            return {
                "success": False,
                "data": "missing tenant ID"
            }

        # Render the tool parameters
        rendered_parameters = {}
        for param_name, param_template in self.typed_config.tool_parameters.items():
@@ -55,9 +55,9 @@ class ToolNode(BaseNode):
                # Non-template parameters (numbers/booleans/plain strings) keep their original value
                rendered_value = param_template
            rendered_parameters[param_name] = rendered_value

        logger.info(f"Node {self.node_id} executing tool {self.typed_config.tool_id} with parameters: {rendered_parameters}")

        # Execute the tool
        with get_db_read() as db:
            tool_service = ToolService(db)
@@ -79,7 +79,7 @@ class ToolNode(BaseNode):
        else:
            logger.error(f"Node {self.node_id} tool execution failed: {result.error}")
            return {
                "data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
                "data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
                "error_code": result.error_code,
                "execution_time": result.execution_time
            }
            }
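The parameter-rendering loop above only sends string values that look like templates through rendering; one plausible reading of the TEMPLATE_PATTERN gate (the helper below is illustrative, not the node's actual code):

import re

TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")

def needs_rendering(value) -> bool:
    # Only strings containing a {{...}} placeholder are rendered;
    # numbers, booleans, and plain strings keep their original value.
    return isinstance(value, str) and bool(TEMPLATE_PATTERN.search(value))

assert needs_rendering("{{sys.message}}")
assert not needs_rendering("plain text") and not needs_rendering(42)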
@@ -10,7 +10,7 @@ class MemoryConfig(Base):

    # Primary key
    config_id = Column(UUID(as_uuid=True), primary_key=True, comment="config ID")

    config_id_old = Column(Integer, nullable=True, comment="backup config ID")
    # Basic information
    config_name = Column(String, nullable=False, comment="config name")
    config_desc = Column(String, nullable=True, comment="config description")
@@ -16,6 +16,10 @@ class Tenants(Base):
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    is_active = Column(Boolean, default=True)

    # SSO external-identity fields
    external_id = Column(String(100), nullable=True, index=True)  # external enterprise ID
    external_source = Column(String(50), nullable=True)  # source system

    # Relationship to users - one tenant has many users
    users = relationship("User", back_populates="tenant")
@@ -18,6 +18,10 @@ class User(Base):
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    last_login_at = Column(DateTime, nullable=True)  # last login time, nullable

    # SSO external-identity fields
    external_id = Column(String(100), nullable=True)  # external user ID
    external_source = Column(String(50), nullable=True)  # source system

    current_workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=True)  # current workspace ID, nullable

    # Foreign key to tenant - each user belongs to exactly one tenant
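The new (external_source, external_id) columns let SSO link local users and tenants to an upstream identity provider. A hedged sketch of a lookup-or-create helper built on these fields (the helper, its signature, and the username field are assumptions, not part of this commit):

from sqlalchemy.orm import Session

def find_or_create_sso_user(db: Session, source: str, ext_id: str, username: str):
    # Hypothetical helper: resolve an SSO identity to a local User row.
    user = (db.query(User)
              .filter(User.external_source == source, User.external_id == ext_id)
              .first())
    if user is None:
        user = User(username=username, external_source=source, external_id=ext_id)
        db.add(user)
        db.commit()
    return user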
api/app/plugins/__init__.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# app/plugins/__init__.py
"""
Plugin system - supports an open-source core plus closed-source premium modules

Usage:
1. Open-source edition (community): base features
2. Commercial edition (enterprise): loads the advanced implementations from the premium package
"""
import os
from typing import Dict, Any, Optional
from app.core.logging_config import get_logger

logger = get_logger(__name__)

# Edition flag
EDITION = os.environ.get("EDITION", "community")
IS_ENTERPRISE = EDITION == "enterprise"

# Plugin registry
_plugins: Dict[str, Any] = {}

# Router registry (for dynamically registering routes from closed-source modules)
_routers: list = []


def is_enterprise() -> bool:
    """Whether this is the commercial edition"""
    return IS_ENTERPRISE


def list_plugins() -> list:
    """List all registered plugins"""
    return list(_plugins.keys())


def register_plugin(name: str, instance: Any):
    """Register a plugin"""
    _plugins[name] = instance
    logger.info(f"Plugin registered: {name}")


def get_plugin(name: str) -> Optional[Any]:
    """Get a plugin instance"""
    return _plugins.get(name)


def register_router(router, prefix: str = "", tags: list = None):
    """Register a router (for use by closed-source modules)"""
    _routers.append({
        "router": router,
        "prefix": prefix,
        "tags": tags or []
    })
    logger.info(f"Router registered: {prefix}")


def get_registered_routers() -> list:
    """Get all registered routers"""
    return _routers


def register_premium_routers(app):
    """
    Mount the premium module routers onto the FastAPI app

    Called from the commercial edition's main.py
    """
    for router_info in _routers:
        app.include_router(
            router_info["router"],
            prefix=f"/api{router_info['prefix']}",
            tags=router_info["tags"]
        )
        logger.info(f"Premium router mounted: /api{router_info['prefix']}")
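A short usage sketch of this registry in an enterprise bootstrap; the audit router and its path are hypothetical:

from fastapi import APIRouter, FastAPI
from app.plugins import is_enterprise, register_plugin, register_router, register_premium_routers

audit_router = APIRouter()

@audit_router.get("/audit/logs")
def list_audit_logs():
    return {"items": []}  # placeholder payload

app = FastAPI()
if is_enterprise():  # True when EDITION=enterprise
    register_plugin("audit", object())  # any instance can serve as the plugin handle
    register_router(audit_router, prefix="/enterprise", tags=["audit"])
    register_premium_routers(app)       # mounts routes under /api/enterprise/...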
@@ -528,7 +528,8 @@ class WorkflowService:
            self.conversation_service.add_message(
                conversation_id=conversation_id_uuid,
                role=message["role"],
                content=message["content"]
                content=message["content"],
                meta_data=None if message["role"] == "user" else {"usage": token_usage}
            )
        logger.info(f"Workflow Run Success, "
                    f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
@@ -678,7 +679,8 @@ class WorkflowService:
            self.conversation_service.add_message(
                conversation_id=conversation_id_uuid,
                role=message["role"],
                content=message["content"]
                content=message["content"],
                meta_data=None if message["role"] == "user" else {"usage": token_usage}
            )
        logger.info(f"Workflow Run Success, "
                    f"execution_id: {execution.execution_id}, message count: {len(final_messages)}")
@@ -15,6 +15,7 @@ services:
    networks:
      - default
      - celery
      - sandbox
    depends_on:
      - worker-memory
      - worker-document
@@ -63,5 +64,16 @@ services:
    depends_on:
      - worker-memory

  sandbox:
    image: redbear_sandbox:latest
    container_name: sandbox
    ports:
      - "8194"
    command: /code/.venv/bin/python main.py
    restart: unless-stopped
    networks:
      - sandbox

networks:
  celery:
  sandbox:
@@ -75,6 +75,7 @@ ENABLE_SINGLE_SESSION=
MAX_FILE_SIZE=52428800 # 50MB: 50 * 1024 * 1024
FILE_PATH=/files

FILE_LOCAL_SERVER_URL="http://localhost:8000/api"
# Storage Backend Configuration
# Supported values: local, oss, s3
# Default: local
api/migrations/versions/75f0ec80e50b_202601271517.py (new file, 57 lines)
@@ -0,0 +1,57 @@
"""202601271517

Revision ID: 75f0ec80e50b
Revises: 325b759cd66b
Create Date: 2026-01-27 15:26:48.696600

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '75f0ec80e50b'
down_revision: Union[str, None] = '325b759cd66b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.alter_column('memory_config', 'config_id',
                    existing_type=sa.UUID(),
                    comment='config ID',
                    existing_nullable=False)
    op.alter_column('memory_config', 'config_id_old',
                    existing_type=sa.INTEGER(),
                    comment='backup config ID',
                    existing_comment='config ID',
                    existing_nullable=True)
    op.add_column('tenants', sa.Column('external_id', sa.String(length=100), nullable=True))
    op.add_column('tenants', sa.Column('external_source', sa.String(length=50), nullable=True))
    op.create_index(op.f('ix_tenants_external_id'), 'tenants', ['external_id'], unique=False)
    op.add_column('users', sa.Column('external_id', sa.String(length=100), nullable=True))
    op.add_column('users', sa.Column('external_source', sa.String(length=50), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('users', 'external_source')
    op.drop_column('users', 'external_id')
    op.drop_index(op.f('ix_tenants_external_id'), table_name='tenants')
    op.drop_column('tenants', 'external_source')
    op.drop_column('tenants', 'external_id')
    op.alter_column('memory_config', 'config_id_old',
                    existing_type=sa.INTEGER(),
                    comment='config ID',
                    existing_comment='backup config ID',
                    existing_nullable=True)
    op.alter_column('memory_config', 'config_id',
                    existing_type=sa.UUID(),
                    comment=None,
                    existing_comment='config ID',
                    existing_nullable=False)
    # ### end Alembic commands ###
@@ -88,7 +88,6 @@ dependencies = [
    "cachetools==6.2.1",
    "ruamel.yaml==0.18.10",
    "strenum==0.4.15",
    "aspose-slides==24.12.0",
    "opencv-python==4.10.0.84",
    "numpy>=1.26.0,<2.0.0",
    "huggingface-hub==0.25.2",
@@ -83,7 +83,6 @@ olefile==0.47
cachetools==6.2.1
ruamel.yaml==0.18.10
strenum==0.4.15
aspose-slides==24.12.0
opencv-python==4.10.0.84
numpy>=1.26.0,<2.0.0
huggingface-hub==0.25.2