Merge pull request #995 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
refactor(http_request)
This commit is contained in:
@@ -272,6 +272,11 @@ class HttpRequestNodeOutput(BaseModel):
|
|||||||
description="HTTP response body",
|
description="HTTP response body",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
process_data: dict = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Raw HTTP request details for debugging",
|
||||||
|
)
|
||||||
|
|
||||||
# files: list[File] = Field(
|
# files: list[File] = Field(
|
||||||
# ...
|
# ...
|
||||||
# )
|
# )
|
||||||
|
|||||||
@@ -160,7 +160,6 @@ class HttpRequestNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
self.last_request: str = ""
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -171,47 +170,6 @@ class HttpRequestNode(BaseNode):
|
|||||||
"output": VariableType.STRING
|
"output": VariableType.STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
|
||||||
if isinstance(business_result, dict):
|
|
||||||
result = {k: v for k, v in business_result.items() if k != "request"}
|
|
||||||
return result
|
|
||||||
return business_result
|
|
||||||
|
|
||||||
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
|
|
||||||
if isinstance(business_result, dict) and "request" in business_result:
|
|
||||||
return {
|
|
||||||
"process": {
|
|
||||||
"request": business_result.get("request", "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _wrap_error(
|
|
||||||
self,
|
|
||||||
error_message: str,
|
|
||||||
elapsed_time: float,
|
|
||||||
state: WorkflowState,
|
|
||||||
variable_pool: VariablePool
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
input_data = self._extract_input(state, variable_pool)
|
|
||||||
node_output = {
|
|
||||||
"node_id": self.node_id,
|
|
||||||
"node_type": self.node_type,
|
|
||||||
"node_name": self.node_name,
|
|
||||||
"status": "failed",
|
|
||||||
"input": input_data,
|
|
||||||
"output": None,
|
|
||||||
"process": {"request": self.last_request} if self.last_request else None,
|
|
||||||
"elapsed_time": elapsed_time,
|
|
||||||
"token_usage": None,
|
|
||||||
"error": error_message
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"node_outputs": {self.node_id: node_output},
|
|
||||||
"error": error_message,
|
|
||||||
"error_node": self.node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
def _build_timeout(self) -> Timeout:
|
def _build_timeout(self) -> Timeout:
|
||||||
"""
|
"""
|
||||||
Build httpx Timeout configuration.
|
Build httpx Timeout configuration.
|
||||||
@@ -297,13 +255,9 @@ class HttpRequestNode(BaseNode):
|
|||||||
case HttpContentType.NONE:
|
case HttpContentType.NONE:
|
||||||
return {}
|
return {}
|
||||||
case HttpContentType.JSON:
|
case HttpContentType.JSON:
|
||||||
rendered_body = self._render_template(
|
content["json"] = json.loads(self._render_template(
|
||||||
self.typed_config.body.data, variable_pool
|
self.typed_config.body.data, variable_pool
|
||||||
).strip()
|
))
|
||||||
if not rendered_body:
|
|
||||||
content["json"] = {}
|
|
||||||
else:
|
|
||||||
content["json"] = json.loads(rendered_body)
|
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
files = []
|
files = []
|
||||||
@@ -371,61 +325,15 @@ class HttpRequestNode(BaseNode):
|
|||||||
case _:
|
case _:
|
||||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||||
|
|
||||||
def _generate_raw_request(
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
self,
|
if isinstance(business_result, dict):
|
||||||
variable_pool: VariablePool,
|
return {k: v for k, v in business_result.items() if k != "process_data"}
|
||||||
url: str,
|
return business_result
|
||||||
headers: dict[str, str],
|
|
||||||
params: dict[str, str],
|
|
||||||
content: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate raw HTTP request format for debugging.
|
|
||||||
|
|
||||||
Args:
|
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||||
variable_pool: Variable Pool
|
if isinstance(business_result, dict) and "process_data" in business_result:
|
||||||
url: Rendered URL
|
return {"process": business_result["process_data"]}
|
||||||
headers: Request headers
|
return {}
|
||||||
params: Query parameters
|
|
||||||
content: Request body content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Raw HTTP request string
|
|
||||||
"""
|
|
||||||
method = self.typed_config.method.value
|
|
||||||
|
|
||||||
if params:
|
|
||||||
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
|
|
||||||
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
|
|
||||||
else:
|
|
||||||
full_url = url
|
|
||||||
|
|
||||||
lines = [f"{method} {full_url} HTTP/1.1"]
|
|
||||||
|
|
||||||
for key, value in headers.items():
|
|
||||||
lines.append(f"{key}: {value}")
|
|
||||||
|
|
||||||
if "json" in content and content["json"]:
|
|
||||||
json_body = json.dumps(content["json"], ensure_ascii=False)
|
|
||||||
lines.append(f"Content-Length: {len(json_body)}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(json_body)
|
|
||||||
elif "data" in content and "files" not in content:
|
|
||||||
if isinstance(content["data"], dict):
|
|
||||||
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
|
|
||||||
lines.append(f"Content-Length: {len(body_str)}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(body_str)
|
|
||||||
elif "content" in content:
|
|
||||||
lines.append(f"Content-Length: {len(content['content'])}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(content["content"])
|
|
||||||
elif "files" in content:
|
|
||||||
lines.append("Content-Length: 0")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("# Note: This request includes file uploads")
|
|
||||||
|
|
||||||
return "\r\n".join(lines)
|
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||||
"""
|
"""
|
||||||
@@ -445,47 +353,42 @@ class HttpRequestNode(BaseNode):
|
|||||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||||
"""
|
"""
|
||||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||||
|
rendered_url = self._render_template(self.typed_config.url, variable_pool)
|
||||||
# Build request components
|
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
||||||
headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
built_params = self._build_params(variable_pool)
|
||||||
params = self._build_params(variable_pool)
|
|
||||||
content = await self._build_content(variable_pool)
|
|
||||||
url = self._render_template(self.typed_config.url, variable_pool)
|
|
||||||
|
|
||||||
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
|
|
||||||
|
|
||||||
# Generate raw HTTP request for debugging
|
|
||||||
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
|
|
||||||
self.last_request = raw_request
|
|
||||||
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
verify=self.typed_config.verify_ssl,
|
verify=self.typed_config.verify_ssl,
|
||||||
timeout=self._build_timeout(),
|
timeout=self._build_timeout(),
|
||||||
headers=headers,
|
headers=built_headers,
|
||||||
params=params,
|
params=built_params,
|
||||||
follow_redirects=True
|
follow_redirects=True
|
||||||
) as client:
|
) as client:
|
||||||
retries = self.typed_config.retry.max_attempts
|
retries = self.typed_config.retry.max_attempts
|
||||||
while retries > 0:
|
while retries > 0:
|
||||||
try:
|
try:
|
||||||
request_func = self._get_client_method(client)
|
request_func = self._get_client_method(client)
|
||||||
|
built_content = await self._build_content(variable_pool)
|
||||||
resp = await request_func(
|
resp = await request_func(
|
||||||
url=url,
|
url=rendered_url,
|
||||||
**content
|
**built_content
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
response = HttpResponse(resp)
|
response = HttpResponse(resp)
|
||||||
return {
|
# Build raw request summary for process_data
|
||||||
**HttpRequestNodeOutput(
|
raw_request = (
|
||||||
body=response.body,
|
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
|
||||||
status_code=resp.status_code,
|
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
|
||||||
headers=resp.headers,
|
+ "\r\n"
|
||||||
files=response.files
|
+ (resp.request.content.decode(errors="replace") if resp.request.content else "")
|
||||||
).model_dump(),
|
)
|
||||||
"request": raw_request
|
return HttpRequestNodeOutput(
|
||||||
}
|
body=response.body,
|
||||||
|
status_code=resp.status_code,
|
||||||
|
headers=resp.headers,
|
||||||
|
files=response.files,
|
||||||
|
process_data={"request": raw_request},
|
||||||
|
).model_dump()
|
||||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||||
logger.error(f"HTTP request node exception: {e}")
|
logger.error(f"HTTP request node exception: {e}")
|
||||||
retries -= 1
|
retries -= 1
|
||||||
@@ -501,19 +404,10 @@ class HttpRequestNode(BaseNode):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||||
)
|
)
|
||||||
error_result = self.typed_config.error_handle.default.model_dump()
|
return self.typed_config.error_handle.default.model_dump()
|
||||||
error_result["request"] = raw_request
|
|
||||||
return error_result
|
|
||||||
case HttpErrorHandle.BRANCH:
|
case HttpErrorHandle.BRANCH:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||||
)
|
)
|
||||||
return {
|
return {"output": "ERROR"}
|
||||||
"output": "ERROR",
|
|
||||||
"body": "",
|
|
||||||
"status_code": 500,
|
|
||||||
"headers": {},
|
|
||||||
"files": [],
|
|
||||||
"request": raw_request
|
|
||||||
}
|
|
||||||
raise RuntimeError("http request failed")
|
raise RuntimeError("http request failed")
|
||||||
|
|||||||
@@ -399,24 +399,6 @@ class AppChatService:
|
|||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
temperature=model_parameters.get("temperature", 0.7),
|
|
||||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
|
||||||
json_output=model_parameters.get("json_output", False),
|
|
||||||
capability=api_key_obj.capability or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
model_name=api_key_obj.model_name,
|
model_name=api_key_obj.model_name,
|
||||||
provider=api_key_obj.provider,
|
provider=api_key_obj.provider,
|
||||||
@@ -471,16 +453,28 @@ class AppChatService:
|
|||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
agent.system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
)
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||||
|
json_output=model_parameters.get("json_output", False),
|
||||||
|
capability=api_key_obj.capability or [],
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
|
|||||||
@@ -875,24 +875,6 @@ class AgentRunService:
|
|||||||
user_rag_memory_id)
|
user_rag_memory_id)
|
||||||
tools.extend(memory_tools)
|
tools.extend(memory_tools)
|
||||||
|
|
||||||
# 4. 创建 LangChain Agent
|
|
||||||
agent = LangChainAgent(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
provider=api_key_config.get("provider", "openai"),
|
|
||||||
api_base=api_key_config.get("api_base"),
|
|
||||||
is_omni=api_key_config.get("is_omni", False),
|
|
||||||
temperature=effective_params.get("temperature", 0.7),
|
|
||||||
max_tokens=effective_params.get("max_tokens", 2000),
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
tools=tools,
|
|
||||||
streaming=True,
|
|
||||||
deep_thinking=effective_params.get("deep_thinking", False),
|
|
||||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
|
||||||
json_output=effective_params.get("json_output", False),
|
|
||||||
capability=api_key_config.get("capability", []),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening, suggested_questions = None, None
|
opening, suggested_questions = None, None
|
||||||
@@ -948,18 +930,28 @@ class AgentRunService:
|
|||||||
and any(f.type == FileType.DOCUMENT for f in files)
|
and any(f.type == FileType.DOCUMENT for f in files)
|
||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
agent.system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
|
||||||
"**规则1:图片URL必须原封不动、一字不差地复制,禁止修改、禁止省略任何字符**"
|
|
||||||
"**规则2:禁止修改URL中UUID里的任何数字和字母**"
|
|
||||||
"**规则3:直接使用  格式输出**"
|
|
||||||
)
|
|
||||||
agent.agent = create_agent(
|
|
||||||
model=agent.llm,
|
|
||||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
|
||||||
system_prompt=agent.system_prompt
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
provider=api_key_config.get("provider", "openai"),
|
||||||
|
api_base=api_key_config.get("api_base"),
|
||||||
|
is_omni=api_key_config.get("is_omni", False),
|
||||||
|
temperature=effective_params.get("temperature", 0.7),
|
||||||
|
max_tokens=effective_params.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True,
|
||||||
|
deep_thinking=effective_params.get("deep_thinking", False),
|
||||||
|
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||||
|
json_output=effective_params.get("json_output", False),
|
||||||
|
capability=api_key_config.get("capability", []),
|
||||||
|
)
|
||||||
|
|
||||||
# 为需要运行时上下文的工具注入上下文
|
# 为需要运行时上下文的工具注入上下文
|
||||||
for t in tools:
|
for t in tools:
|
||||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||||
|
|||||||
Reference in New Issue
Block a user