Merge pull request #928 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

fix(llm)
This commit is contained in:
山程漫悟
2026-04-17 14:23:16 +08:00
committed by GitHub
5 changed files with 51 additions and 26 deletions

View File

@@ -12,7 +12,7 @@ import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.errors import GraphRecursionError from langgraph.errors import GraphRecursionError
@@ -83,7 +83,12 @@ class LangChainAgent:
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format # ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
# 在 system prompt 中注入 JSON 要求 # 在 system prompt 中注入 JSON 要求
from app.models.models_model import ModelProvider from app.models.models_model import ModelProvider
if json_output and provider.lower() == ModelProvider.DASHSCOPE and not is_omni: if json_output and (
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
or provider.lower() == ModelProvider.VOLCANO
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
or bool(tools)
):
self.system_prompt += "\n请以JSON格式输出。" self.system_prompt += "\n请以JSON格式输出。"
logger.debug( logger.debug(
@@ -240,9 +245,7 @@ class LangChainAgent:
Returns: Returns:
List[BaseMessage]: 消息列表 List[BaseMessage]: 消息列表
""" """
messages:list = [SystemMessage(content=self.system_prompt)] messages: list = []
# 添加系统提示词
# 添加历史消息 # 添加历史消息
if history: if history:

View File

@@ -101,12 +101,10 @@ class RedBearModelFactory:
extra_body["enable_thinking"] = True extra_body["enable_thinking"] = True
if config.thinking_budget_tokens: if config.thinking_budget_tokens:
extra_body["thinking_budget"] = config.thinking_budget_tokens extra_body["thinking_budget"] = config.thinking_budget_tokens
params["extra_body"] = extra_body
# JSON 输出模式 # JSON 输出模式
if config.json_output: if config.json_output:
model_kwargs = params.setdefault("model_kwargs", {}) model_kwargs = params.setdefault("model_kwargs", {})
model_kwargs["response_format"] = {"type": "json_object"} model_kwargs["response_format"] = {"type": "json_object"}
params["model_kwargs"] = model_kwargs
return params return params
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
@@ -148,11 +146,12 @@ class RedBearModelFactory:
extra_body["enable_thinking"] = True extra_body["enable_thinking"] = True
if config.thinking_budget_tokens: if config.thinking_budget_tokens:
extra_body["thinking_budget"] = config.thinking_budget_tokens extra_body["thinking_budget"] = config.thinking_budget_tokens
params["extra_body"] = extra_body
# JSON 输出模式 # JSON 输出模式
if config.json_output: if config.json_output:
params.setdefault("model_kwargs", {}) model_kwargs = params.setdefault("model_kwargs", {})
params["model_kwargs"]["response_format"] = {"type": "json_object"} # VOLCANO 模型不支持 response_format，JSON 输出由 system prompt 注入实现
if provider != ModelProvider.VOLCANO:
model_kwargs["response_format"] = {"type": "json_object"}
return params return params
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
params = { params = {
@@ -172,11 +171,9 @@ class RedBearModelFactory:
model_kwargs["incremental_output"] = True model_kwargs["incremental_output"] = True
if config.thinking_budget_tokens: if config.thinking_budget_tokens:
model_kwargs["thinking_budget"] = config.thinking_budget_tokens model_kwargs["thinking_budget"] = config.thinking_budget_tokens
params["model_kwargs"] = model_kwargs
if config.json_output: if config.json_output:
model_kwargs = params.setdefault("model_kwargs", {}) model_kwargs = params.setdefault("model_kwargs", {})
model_kwargs["response_format"] = {"type": "json_object"} model_kwargs["response_format"] = {"type": "json_object"}
params["model_kwargs"] = model_kwargs
return params return params
elif provider == ModelProvider.BEDROCK: elif provider == ModelProvider.BEDROCK:
# Bedrock 使用 AWS 凭证 # Bedrock 使用 AWS 凭证
@@ -225,8 +222,8 @@ class RedBearModelFactory:
} }
# JSON 输出模式 # JSON 输出模式
if config.json_output: if config.json_output:
params.setdefault("model_kwargs", {}) model_kwargs = params.setdefault("model_kwargs", {})
params["model_kwargs"]["response_format"] = {"type": "json_object"} model_kwargs["response_format"] = {"type": "json_object"}
return params return params
else: else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@@ -261,12 +258,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
if provider == ModelProvider.VOLCANO: if provider == ModelProvider.VOLCANO:
return CompatibleChatOpenAI return CompatibleChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM: return CompatibleChatOpenAI
return OpenAI # if type == ModelType.LLM:
elif type == ModelType.CHAT: # return OpenAI
return ChatOpenAI # elif type == ModelType.CHAT:
else: # return CompatibleChatOpenAI
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED) # else:
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:

View File

@@ -8,12 +8,33 @@ from __future__ import annotations
from typing import Any, Optional, Union from typing import Any, Optional, Union
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
class CompatibleChatOpenAI(ChatOpenAI): class CompatibleChatOpenAI(ChatOpenAI):
"""火山和千问的omni兼容模型支持深度思考内容reasoning_content的流式和非流式透传。""" """火山和千问的omni兼容模型支持深度思考内容reasoning_content的流式和非流式透传。
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format，
让父类走普通 .create()/.astream() 路径，JSON 输出由 system prompt 指令保证。
"""
def _get_request_payload(
    self,
    input_: list[BaseMessage],
    *,
    stop: list[str] | None = None,
    **kwargs: Any,
) -> dict:
    """Build the request payload, dropping ``response_format`` when tools are sent.

    The payload is produced by the parent class and then post-processed:
    if it carries a non-empty ``tools`` entry, any ``response_format`` key
    is removed before the payload is returned.

    Rationale (per the original author's note): when ``response_format`` is
    present, langchain_openai switches to the ``.parse()`` / ``.stream()``
    endpoints, and the OpenAI SDK then requires every tool to be declared
    with ``strict=True`` -- which dynamically generated tools do not
    satisfy. Removing the key keeps the parent on the regular request path;
    JSON output is then expected to be enforced by a system-prompt
    instruction instead.

    Args:
        input_: Messages to send to the model.
        stop: Optional stop sequences, forwarded to the parent unchanged.
        **kwargs: Extra request options, forwarded to the parent unchanged.

    Returns:
        The request payload dict, without ``response_format`` whenever
        ``tools`` is present and non-empty.
    """
    request = super()._get_request_payload(input_, stop=stop, **kwargs)

    # No tools in the request: response_format is safe to keep as-is.
    if not request.get("tools"):
        return request

    # Tools present: strip response_format (no-op when the key is absent).
    request.pop("response_format", None)
    return request
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult: def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
result = super()._create_chat_result(response, generation_info) result = super()._create_chat_result(response, generation_info)

View File

@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
return { return {
"datetime": input_value, "datetime": input_value,
"timezone": timezone_str, "timezone": timezone_str,
"timestamp": int(dt.timestamp()) * 1000, "timestamp": int(dt.timestamp() * 1000),
"iso_format": dt.isoformat(), "iso_format": dt.isoformat(),
"result_data": int(dt.timestamp()) * 1000 "result_data": int(dt.timestamp() * 1000)
} }
def _calculate_datetime(self, kwargs) -> dict: def _calculate_datetime(self, kwargs) -> dict:

View File

@@ -226,9 +226,12 @@ class LLMNode(BaseNode):
self.messages = [{"role": "user", "content": rendered}] self.messages = [{"role": "user", "content": rendered}]
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format，在 system prompt 中注入 # ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format，在 system prompt 中注入
if (self.typed_config.json_output # VOLCANO 模型不支持 response_format，同样需要 system prompt 注入
and model_info.provider.lower() == ModelProvider.DASHSCOPE need_json_prompt = self.typed_config.json_output and (
and not model_info.is_omni): (model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
or model_info.provider.lower() == ModelProvider.VOLCANO
)
if need_json_prompt:
system_msg = next((m for m in self.messages if m["role"] == "system"), None) system_msg = next((m for m in self.messages if m["role"] == "system"), None)
if system_msg: if system_msg:
system_msg["content"] += "\n请以JSON格式输出。" system_msg["content"] += "\n请以JSON格式输出。"