Merge pull request #928 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(llm)
This commit is contained in:
@@ -12,7 +12,7 @@ import time
|
|||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.errors import GraphRecursionError
|
from langgraph.errors import GraphRecursionError
|
||||||
|
|
||||||
@@ -83,7 +83,12 @@ class LangChainAgent:
|
|||||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format
|
||||||
# 在 system prompt 中注入 JSON 要求
|
# 在 system prompt 中注入 JSON 要求
|
||||||
from app.models.models_model import ModelProvider
|
from app.models.models_model import ModelProvider
|
||||||
if json_output and provider.lower() == ModelProvider.DASHSCOPE and not is_omni:
|
if json_output and (
|
||||||
|
(provider.lower() == ModelProvider.DASHSCOPE and not is_omni)
|
||||||
|
or provider.lower() == ModelProvider.VOLCANO
|
||||||
|
# 有工具时 response_format 会被移除,所有 provider 都需要 system prompt 注入保证 JSON 输出
|
||||||
|
or bool(tools)
|
||||||
|
):
|
||||||
self.system_prompt += "\n请以JSON格式输出。"
|
self.system_prompt += "\n请以JSON格式输出。"
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -240,9 +245,7 @@ class LangChainAgent:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
messages: list = []
|
||||||
|
|
||||||
# 添加系统提示词
|
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
|
|||||||
@@ -101,12 +101,10 @@ class RedBearModelFactory:
|
|||||||
extra_body["enable_thinking"] = True
|
extra_body["enable_thinking"] = True
|
||||||
if config.thinking_budget_tokens:
|
if config.thinking_budget_tokens:
|
||||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||||
params["extra_body"] = extra_body
|
|
||||||
# JSON 输出模式
|
# JSON 输出模式
|
||||||
if config.json_output:
|
if config.json_output:
|
||||||
model_kwargs = params.setdefault("model_kwargs", {})
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
model_kwargs["response_format"] = {"type": "json_object"}
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
params["model_kwargs"] = model_kwargs
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||||
@@ -148,11 +146,12 @@ class RedBearModelFactory:
|
|||||||
extra_body["enable_thinking"] = True
|
extra_body["enable_thinking"] = True
|
||||||
if config.thinking_budget_tokens:
|
if config.thinking_budget_tokens:
|
||||||
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
extra_body["thinking_budget"] = config.thinking_budget_tokens
|
||||||
params["extra_body"] = extra_body
|
|
||||||
# JSON 输出模式
|
# JSON 输出模式
|
||||||
if config.json_output:
|
if config.json_output:
|
||||||
params.setdefault("model_kwargs", {})
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
params["model_kwargs"]["response_format"] = {"type": "json_object"}
|
# VOLCANO 模型不支持 response_format,JSON 输出由 system prompt 注入实现
|
||||||
|
if provider != ModelProvider.VOLCANO:
|
||||||
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
params = {
|
params = {
|
||||||
@@ -172,11 +171,9 @@ class RedBearModelFactory:
|
|||||||
model_kwargs["incremental_output"] = True
|
model_kwargs["incremental_output"] = True
|
||||||
if config.thinking_budget_tokens:
|
if config.thinking_budget_tokens:
|
||||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||||
params["model_kwargs"] = model_kwargs
|
|
||||||
if config.json_output:
|
if config.json_output:
|
||||||
model_kwargs = params.setdefault("model_kwargs", {})
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
model_kwargs["response_format"] = {"type": "json_object"}
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
params["model_kwargs"] = model_kwargs
|
|
||||||
return params
|
return params
|
||||||
elif provider == ModelProvider.BEDROCK:
|
elif provider == ModelProvider.BEDROCK:
|
||||||
# Bedrock 使用 AWS 凭证
|
# Bedrock 使用 AWS 凭证
|
||||||
@@ -225,8 +222,8 @@ class RedBearModelFactory:
|
|||||||
}
|
}
|
||||||
# JSON 输出模式
|
# JSON 输出模式
|
||||||
if config.json_output:
|
if config.json_output:
|
||||||
params.setdefault("model_kwargs", {})
|
model_kwargs = params.setdefault("model_kwargs", {})
|
||||||
params["model_kwargs"]["response_format"] = {"type": "json_object"}
|
model_kwargs["response_format"] = {"type": "json_object"}
|
||||||
return params
|
return params
|
||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
@@ -261,12 +258,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
|||||||
if provider == ModelProvider.VOLCANO:
|
if provider == ModelProvider.VOLCANO:
|
||||||
return CompatibleChatOpenAI
|
return CompatibleChatOpenAI
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
if type == ModelType.LLM:
|
return CompatibleChatOpenAI
|
||||||
return OpenAI
|
# if type == ModelType.LLM:
|
||||||
elif type == ModelType.CHAT:
|
# return OpenAI
|
||||||
return ChatOpenAI
|
# elif type == ModelType.CHAT:
|
||||||
else:
|
# return CompatibleChatOpenAI
|
||||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
# else:
|
||||||
|
# raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
return ChatTongyi
|
return ChatTongyi
|
||||||
elif provider == ModelProvider.OLLAMA:
|
elif provider == ModelProvider.OLLAMA:
|
||||||
|
|||||||
@@ -8,12 +8,33 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
class CompatibleChatOpenAI(ChatOpenAI):
|
class CompatibleChatOpenAI(ChatOpenAI):
|
||||||
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。"""
|
"""火山和千问的omni兼容模型,支持深度思考内容(reasoning_content)的流式和非流式透传。
|
||||||
|
|
||||||
|
同时修复 json_output + tools 同时使用时 langchain_openai 强制走 .parse()/.stream()
|
||||||
|
导致 strict 校验报错的问题:有工具时从 payload 中移除 response_format,
|
||||||
|
让父类走普通 .create()/.astream() 路径,JSON 输出由 system prompt 指令保证。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_request_payload(
|
||||||
|
self,
|
||||||
|
input_: list[BaseMessage],
|
||||||
|
*,
|
||||||
|
stop: list[str] | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict:
|
||||||
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||||
|
# 有工具时 langchain_openai 检测到 response_format 会切换到 .parse()/.stream()
|
||||||
|
# 接口,OpenAI SDK 要求此时所有工具必须 strict=True,动态生成的工具不满足。
|
||||||
|
# 移除 response_format,让父类走普通路径,JSON 输出由 system prompt 指令保证。
|
||||||
|
if payload.get("tools") and "response_format" in payload:
|
||||||
|
payload.pop("response_format")
|
||||||
|
return payload
|
||||||
|
|
||||||
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
def _create_chat_result(self, response: Union[dict, Any], generation_info: Optional[dict] = None) -> ChatResult:
|
||||||
result = super()._create_chat_result(response, generation_info)
|
result = super()._create_chat_result(response, generation_info)
|
||||||
|
|||||||
@@ -253,9 +253,9 @@ class DateTimeTool(BuiltinTool):
|
|||||||
return {
|
return {
|
||||||
"datetime": input_value,
|
"datetime": input_value,
|
||||||
"timezone": timezone_str,
|
"timezone": timezone_str,
|
||||||
"timestamp": int(dt.timestamp()) * 1000,
|
"timestamp": int(dt.timestamp() * 1000),
|
||||||
"iso_format": dt.isoformat(),
|
"iso_format": dt.isoformat(),
|
||||||
"result_data": int(dt.timestamp()) * 1000
|
"result_data": int(dt.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _calculate_datetime(self, kwargs) -> dict:
|
def _calculate_datetime(self, kwargs) -> dict:
|
||||||
|
|||||||
@@ -226,9 +226,12 @@ class LLMNode(BaseNode):
|
|||||||
self.messages = [{"role": "user", "content": rendered}]
|
self.messages = [{"role": "user", "content": rendered}]
|
||||||
|
|
||||||
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入
|
# ChatTongyi 要求 messages 含 'json' 字样才能使用 response_format,在 system prompt 中注入
|
||||||
if (self.typed_config.json_output
|
# VOLCANO 模型不支持 response_format,同样需要 system prompt 注入
|
||||||
and model_info.provider.lower() == ModelProvider.DASHSCOPE
|
need_json_prompt = self.typed_config.json_output and (
|
||||||
and not model_info.is_omni):
|
(model_info.provider.lower() == ModelProvider.DASHSCOPE and not model_info.is_omni)
|
||||||
|
or model_info.provider.lower() == ModelProvider.VOLCANO
|
||||||
|
)
|
||||||
|
if need_json_prompt:
|
||||||
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
|
system_msg = next((m for m in self.messages if m["role"] == "system"), None)
|
||||||
if system_msg:
|
if system_msg:
|
||||||
system_msg["content"] += "\n请以JSON格式输出。"
|
system_msg["content"] += "\n请以JSON格式输出。"
|
||||||
|
|||||||
Reference in New Issue
Block a user