Merge pull request #995 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
refactor(http_request)
This commit is contained in:
@@ -272,6 +272,11 @@ class HttpRequestNodeOutput(BaseModel):
|
||||
description="HTTP response body",
|
||||
)
|
||||
|
||||
process_data: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Raw HTTP request details for debugging",
|
||||
)
|
||||
|
||||
# files: list[File] = Field(
|
||||
# ...
|
||||
# )
|
||||
|
||||
@@ -160,7 +160,6 @@ class HttpRequestNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
self.last_request: str = ""
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -171,47 +170,6 @@ class HttpRequestNode(BaseNode):
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
if isinstance(business_result, dict):
|
||||
result = {k: v for k, v in business_result.items() if k != "request"}
|
||||
return result
|
||||
return business_result
|
||||
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict[str, Any]:
|
||||
if isinstance(business_result, dict) and "request" in business_result:
|
||||
return {
|
||||
"process": {
|
||||
"request": business_result.get("request", "")
|
||||
}
|
||||
}
|
||||
return {}
|
||||
|
||||
def _wrap_error(
|
||||
self,
|
||||
error_message: str,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> dict[str, Any]:
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
"node_name": self.node_name,
|
||||
"status": "failed",
|
||||
"input": input_data,
|
||||
"output": None,
|
||||
"process": {"request": self.last_request} if self.last_request else None,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None,
|
||||
"error": error_message
|
||||
}
|
||||
return {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
"error": error_message,
|
||||
"error_node": self.node_id
|
||||
}
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
Build httpx Timeout configuration.
|
||||
@@ -297,13 +255,9 @@ class HttpRequestNode(BaseNode):
|
||||
case HttpContentType.NONE:
|
||||
return {}
|
||||
case HttpContentType.JSON:
|
||||
rendered_body = self._render_template(
|
||||
content["json"] = json.loads(self._render_template(
|
||||
self.typed_config.body.data, variable_pool
|
||||
).strip()
|
||||
if not rendered_body:
|
||||
content["json"] = {}
|
||||
else:
|
||||
content["json"] = json.loads(rendered_body)
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
files = []
|
||||
@@ -371,61 +325,15 @@ class HttpRequestNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||
|
||||
def _generate_raw_request(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
params: dict[str, str],
|
||||
content: dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Generate raw HTTP request format for debugging.
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
if isinstance(business_result, dict):
|
||||
return {k: v for k, v in business_result.items() if k != "process_data"}
|
||||
return business_result
|
||||
|
||||
Args:
|
||||
variable_pool: Variable Pool
|
||||
url: Rendered URL
|
||||
headers: Request headers
|
||||
params: Query parameters
|
||||
content: Request body content
|
||||
|
||||
Returns:
|
||||
Raw HTTP request string
|
||||
"""
|
||||
method = self.typed_config.method.value
|
||||
|
||||
if params:
|
||||
param_str = "&".join([f"{k}={v}" for k, v in params.items()])
|
||||
full_url = f"{url}?{param_str}" if "?" not in url else f"{url}&{param_str}"
|
||||
else:
|
||||
full_url = url
|
||||
|
||||
lines = [f"{method} {full_url} HTTP/1.1"]
|
||||
|
||||
for key, value in headers.items():
|
||||
lines.append(f"{key}: {value}")
|
||||
|
||||
if "json" in content and content["json"]:
|
||||
json_body = json.dumps(content["json"], ensure_ascii=False)
|
||||
lines.append(f"Content-Length: {len(json_body)}")
|
||||
lines.append("")
|
||||
lines.append(json_body)
|
||||
elif "data" in content and "files" not in content:
|
||||
if isinstance(content["data"], dict):
|
||||
body_str = "&".join([f"{k}={v}" for k, v in content["data"].items()])
|
||||
lines.append(f"Content-Length: {len(body_str)}")
|
||||
lines.append("")
|
||||
lines.append(body_str)
|
||||
elif "content" in content:
|
||||
lines.append(f"Content-Length: {len(content['content'])}")
|
||||
lines.append("")
|
||||
lines.append(content["content"])
|
||||
elif "files" in content:
|
||||
lines.append("Content-Length: 0")
|
||||
lines.append("")
|
||||
lines.append("# Note: This request includes file uploads")
|
||||
|
||||
return "\r\n".join(lines)
|
||||
def _extract_extra_fields(self, business_result: Any) -> dict:
|
||||
if isinstance(business_result, dict) and "process_data" in business_result:
|
||||
return {"process": business_result["process_data"]}
|
||||
return {}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||
"""
|
||||
@@ -445,47 +353,42 @@ class HttpRequestNode(BaseNode):
|
||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||
"""
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
|
||||
# Build request components
|
||||
headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
||||
params = self._build_params(variable_pool)
|
||||
content = await self._build_content(variable_pool)
|
||||
url = self._render_template(self.typed_config.url, variable_pool)
|
||||
|
||||
logger.info(f"Node {self.node_id}: headers={headers}, params={params}, content keys={list(content.keys())}")
|
||||
|
||||
# Generate raw HTTP request for debugging
|
||||
raw_request = self._generate_raw_request(variable_pool, url, headers, params, content)
|
||||
self.last_request = raw_request
|
||||
logger.info(f"Node {self.node_id}: Generated HTTP request:\n{raw_request}")
|
||||
|
||||
rendered_url = self._render_template(self.typed_config.url, variable_pool)
|
||||
built_headers = self._build_header(variable_pool) | self._build_auth(variable_pool)
|
||||
built_params = self._build_params(variable_pool)
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
headers=headers,
|
||||
params=params,
|
||||
headers=built_headers,
|
||||
params=built_params,
|
||||
follow_redirects=True
|
||||
) as client:
|
||||
retries = self.typed_config.retry.max_attempts
|
||||
while retries > 0:
|
||||
try:
|
||||
request_func = self._get_client_method(client)
|
||||
built_content = await self._build_content(variable_pool)
|
||||
resp = await request_func(
|
||||
url=url,
|
||||
**content
|
||||
url=rendered_url,
|
||||
**built_content
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
return {
|
||||
**HttpRequestNodeOutput(
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files
|
||||
).model_dump(),
|
||||
"request": raw_request
|
||||
}
|
||||
# Build raw request summary for process_data
|
||||
raw_request = (
|
||||
f"{self.typed_config.method.upper()} {resp.request.url} HTTP/1.1\r\n"
|
||||
+ "".join(f"{k}: {v}\r\n" for k, v in resp.request.headers.items())
|
||||
+ "\r\n"
|
||||
+ (resp.request.content.decode(errors="replace") if resp.request.content else "")
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files,
|
||||
process_data={"request": raw_request},
|
||||
).model_dump()
|
||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||
logger.error(f"HTTP request node exception: {e}")
|
||||
retries -= 1
|
||||
@@ -501,19 +404,10 @@ class HttpRequestNode(BaseNode):
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
)
|
||||
error_result = self.typed_config.error_handle.default.model_dump()
|
||||
error_result["request"] = raw_request
|
||||
return error_result
|
||||
return self.typed_config.error_handle.default.model_dump()
|
||||
case HttpErrorHandle.BRANCH:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return {
|
||||
"output": "ERROR",
|
||||
"body": "",
|
||||
"status_code": 500,
|
||||
"headers": {},
|
||||
"files": [],
|
||||
"request": raw_request
|
||||
}
|
||||
return {"output": "ERROR"}
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -399,24 +399,6 @@ class AppChatService:
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True,
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
json_output=model_parameters.get("json_output", False),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
@@ -471,16 +453,28 @@ class AppChatService:
|
||||
f.type == FileType.DOCUMENT for f in files
|
||||
):
|
||||
from langchain.agents import create_agent
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
)
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
system_prompt += (
|
||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||
)
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True,
|
||||
deep_thinking=model_parameters.get("deep_thinking", False),
|
||||
thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
|
||||
json_output=model_parameters.get("json_output", False),
|
||||
capability=api_key_obj.capability or [],
|
||||
)
|
||||
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||
|
||||
@@ -875,24 +875,6 @@ class AgentRunService:
|
||||
user_rag_memory_id)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_config["model_name"],
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True,
|
||||
deep_thinking=effective_params.get("deep_thinking", False),
|
||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||
json_output=effective_params.get("json_output", False),
|
||||
capability=api_key_config.get("capability", []),
|
||||
)
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening, suggested_questions = None, None
|
||||
@@ -948,18 +930,28 @@ class AgentRunService:
|
||||
and any(f.type == FileType.DOCUMENT for f in files)
|
||||
)
|
||||
if has_doc_with_images:
|
||||
agent.system_prompt += (
|
||||
"\n\n文档中包含图片,图片位置已在文本中以 [图片 第N页 第M张图片]: URL 标记。"
|
||||
"请在回答中用 Markdown 格式  展示相关图片,做到图文并茂。"
|
||||
"**规则1:图片URL必须原封不动、一字不差地复制,禁止修改、禁止省略任何字符**"
|
||||
"**规则2:禁止修改URL中UUID里的任何数字和字母**"
|
||||
"**规则3:直接使用  格式输出**"
|
||||
)
|
||||
agent.agent = create_agent(
|
||||
model=agent.llm,
|
||||
tools=agent._wrap_tools_with_tracking(agent.tools) if agent.tools else None,
|
||||
system_prompt=agent.system_prompt
|
||||
system_prompt += (
|
||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
||||
)
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_config["model_name"],
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True,
|
||||
deep_thinking=effective_params.get("deep_thinking", False),
|
||||
thinking_budget_tokens=effective_params.get("thinking_budget_tokens"),
|
||||
json_output=effective_params.get("json_output", False),
|
||||
capability=api_key_config.get("capability", []),
|
||||
)
|
||||
|
||||
# 为需要运行时上下文的工具注入上下文
|
||||
for t in tools:
|
||||
if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
|
||||
|
||||
Reference in New Issue
Block a user