Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -117,7 +117,7 @@ async def get_prompt_opt(
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: bool
|
||||
looping: Annotated[bool, lambda x, y: x and y]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
|
||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
||||
retries -= 1
|
||||
if retries > 0:
|
||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -203,15 +203,16 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
1. 简单模式:使用 prompt 字段
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
|
||||
|
||||
context: Any = Field(
|
||||
default="",
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
description="提示词模板(简单模式),支持变量引用"
|
||||
)
|
||||
|
||||
|
||||
# 消息模式(推荐)
|
||||
messages: list[MessageConfig] | None = Field(
|
||||
default=None,
|
||||
description="消息列表(消息模式),支持多轮对话"
|
||||
)
|
||||
|
||||
|
||||
# 模型参数
|
||||
temperature: float | None = Field(
|
||||
default=0.7,
|
||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
le=2.0,
|
||||
description="温度参数,控制输出的随机性"
|
||||
)
|
||||
|
||||
|
||||
max_tokens: int | None = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=32000,
|
||||
description="最大生成 token 数"
|
||||
)
|
||||
|
||||
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="频率惩罚"
|
||||
)
|
||||
|
||||
|
||||
presence_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="存在惩罚"
|
||||
)
|
||||
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
|
||||
@@ -5,15 +5,17 @@ LLM 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
@@ -63,8 +65,15 @@ class LLMNode(BaseNode):
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
def _render_context(self, message,state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -76,15 +85,16 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
@@ -95,7 +105,7 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
@@ -106,17 +116,17 @@ class LLMNode(BaseNode):
|
||||
model_id = self.config.get("model_id")
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
|
||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||
with get_db_context() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
|
||||
if not config:
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
@@ -124,26 +134,26 @@ class LLMNode(BaseNode):
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
@@ -153,10 +163,10 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
# 提取内容
|
||||
@@ -164,16 +174,16 @@ class LLMNode(BaseNode):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
@@ -186,13 +196,13 @@ class LLMNode(BaseNode):
|
||||
"max_tokens": self.config.get("max_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_output(self, business_result: Any) -> str:
|
||||
"""从 AIMessage 中提取文本内容"""
|
||||
if isinstance(business_result, AIMessage):
|
||||
return business_result.content
|
||||
return str(business_result)
|
||||
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从 AIMessage 中提取 token 使用情况"""
|
||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||
@@ -204,7 +214,7 @@ class LLMNode(BaseNode):
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
@@ -215,26 +225,26 @@ class LLMNode(BaseNode):
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
@@ -246,12 +256,12 @@ class LLMNode(BaseNode):
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
@@ -259,18 +269,18 @@ class LLMNode(BaseNode):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
|
||||
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
@@ -279,6 +289,6 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
|
||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
|
||||
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||
if category_count > 0:
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
try:
|
||||
llm = self._get_llm_instance()
|
||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
|
||||
return f"CASE{category_names.index(category) + 1}"
|
||||
return {
|
||||
"class_name": category,
|
||||
"output": f"CASE{category_names.index(category) + 1}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
# 异常时返回默认分支,保证工作流容错性
|
||||
if category_count > 0:
|
||||
return DEFAULT_EMPTY_QUESTION_CASE
|
||||
return "unknown"
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
||||
"""工具节点配置"""
|
||||
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||
if workflow_id:
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
if workspace_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
# return {"error": "缺少租户ID"}
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
print(self.typed_config.tool_id)
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
print(result)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"data": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
publish: 发布验证标识
|
||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||
|
||||
Returns:
|
||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
for index, graph in enumerate(graphs):
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
@@ -125,10 +126,11 @@ class WorkflowValidator:
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
if index == len(graphs) - 1:
|
||||
# 2. 验证 主图end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
@@ -159,15 +161,17 @@ class WorkflowValidator:
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
if publish:
|
||||
# 仅在发布时验证所有节点可达
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
@@ -38,6 +38,33 @@ class ToolRepository:
|
||||
|
||||
return result[0] if result else None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
|
||||
"""
|
||||
根据空间ID获取tenant_id
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 空间ID
|
||||
|
||||
Returns:
|
||||
tenant_id或None
|
||||
"""
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
tenant_id = db.query(Workspace.tenant_id).filter(
|
||||
Workspace.id == workspace_id
|
||||
).scalar()
|
||||
|
||||
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
|
||||
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
|
||||
try:
|
||||
tenant_id = uuid.UUID(tenant_id)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
return tenant_id
|
||||
|
||||
@staticmethod
|
||||
def find_by_tenant(
|
||||
db: Session,
|
||||
|
||||
@@ -231,9 +231,9 @@ class PromptOptimizerService:
|
||||
if m:
|
||||
prompt_index = m.start()
|
||||
prompt_finished = True
|
||||
yield {"type": "delta", "content": buffer[idx:prompt_index]}
|
||||
yield {"content": buffer[idx:prompt_index]}
|
||||
else:
|
||||
yield {"type": "delta", "content": cache[idx:]}
|
||||
yield {"content": cache[idx:]}
|
||||
if len(cache) != 0:
|
||||
idx = len(cache)
|
||||
|
||||
@@ -249,8 +249,8 @@ class PromptOptimizerService:
|
||||
role=RoleType.ASSISTANT,
|
||||
content=desc
|
||||
)
|
||||
|
||||
yield {"type": "done", "desc": optim_result.get("desc")}
|
||||
variables = self.parser_prompt_variables(optim_result.get("prompt"))
|
||||
yield {"desc": optim_result.get("desc"), "variables": variables}
|
||||
|
||||
@staticmethod
|
||||
def parser_prompt_variables(prompt: str):
|
||||
|
||||
@@ -344,14 +344,16 @@ class ToolService:
|
||||
break
|
||||
|
||||
if operation_param:
|
||||
# 有多个操作
|
||||
# 有多个操作,为每个操作生成具体参数
|
||||
methods = []
|
||||
for operation in operation_param.enum:
|
||||
# 获取该操作的具体参数
|
||||
operation_params = self._get_operation_specific_params(tool_instance, operation)
|
||||
methods.append({
|
||||
"method_id": f"{config.name}_{operation}",
|
||||
"name": operation,
|
||||
"description": f"{config.description} - {operation}",
|
||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||
"parameters": operation_params
|
||||
})
|
||||
return methods
|
||||
else:
|
||||
@@ -362,6 +364,243 @@ class ToolService:
|
||||
"description": config.description,
|
||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||
}]
|
||||
|
||||
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取特定操作的参数列表"""
|
||||
# 对于datetime_tool,根据操作类型返回相关参数
|
||||
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
|
||||
return self._get_datetime_tool_params(operation)
|
||||
# 对于json_tool,根据操作类型返回相关参数
|
||||
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
|
||||
return self._get_json_tool_params(operation)
|
||||
|
||||
# 其他工具的默认处理:返回除operation外的所有参数
|
||||
return [{
|
||||
"name": param.name,
|
||||
"type": param.type.value,
|
||||
"description": param.description,
|
||||
"required": param.required,
|
||||
"default": param.default,
|
||||
"enum": param.enum,
|
||||
"minimum": param.minimum,
|
||||
"maximum": param.maximum,
|
||||
"pattern": param.pattern
|
||||
} for param in tool_instance.parameters if param.name != "operation"]
|
||||
|
||||
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取datetime_tool特定操作的参数"""
|
||||
if operation == "now":
|
||||
return [
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
}
|
||||
]
|
||||
elif operation == "format":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
}
|
||||
]
|
||||
elif operation == "convert_timezone":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "from_timezone",
|
||||
"type": "string",
|
||||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
}
|
||||
]
|
||||
elif operation == "timestamp_to_datetime":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
}
|
||||
]
|
||||
else:
|
||||
# 默认返回所有参数(除了operation)
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "from_timezone",
|
||||
"type": "string",
|
||||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "calculation",
|
||||
"type": "string",
|
||||
"description": "时间计算表达式(如:+1d, -2h, +30m)",
|
||||
"required": False
|
||||
}
|
||||
]
|
||||
|
||||
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取json_tool特定操作的参数"""
|
||||
base_params = [
|
||||
{
|
||||
"name": "input_data",
|
||||
"type": "string",
|
||||
"description": "输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
|
||||
if operation == "insert":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "new_value",
|
||||
"type": "string",
|
||||
"description": "新值(用于insert操作)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "replace":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "old_text",
|
||||
"type": "string",
|
||||
"description": "要替换的原文本(用于replace操作)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "new_text",
|
||||
"type": "string",
|
||||
"description": "替换后的新文本(用于replace操作)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "delete":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "parse":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
|
||||
return base_params
|
||||
|
||||
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||
"""获取自定义工具的方法"""
|
||||
|
||||
Reference in New Issue
Block a user