feat(model and app):
1. Add support for vision models and multimodal models; 2. Applications and workflows can now accept multimodal file inputs such as images, documents, audio, and video.
This commit is contained in:
@@ -11,35 +11,37 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -60,12 +62,13 @@ class LangChainAgent:
|
||||
self.provider = provider
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
@@ -73,9 +76,9 @@ class LangChainAgent:
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -89,6 +92,7 @@ class LangChainAgent:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -143,21 +147,22 @@ class LangChainAgent:
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
@@ -168,13 +173,13 @@ class LangChainAgent:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
@@ -185,12 +190,12 @@ class LangChainAgent:
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
@@ -198,17 +203,17 @@ class LangChainAgent:
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -248,7 +253,7 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -261,23 +266,26 @@ class LangChainAgent:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
@@ -302,7 +310,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
@@ -322,8 +330,8 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -367,14 +375,14 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
@@ -407,12 +415,13 @@ class LangChainAgent:
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -439,16 +448,16 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -482,7 +491,6 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
@@ -500,13 +508,13 @@ class LangChainAgent:
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
@@ -540,7 +548,7 @@ class LangChainAgent:
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
@@ -577,13 +585,13 @@ class LangChainAgent:
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
@@ -595,7 +603,8 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -609,5 +618,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user