Files
MemoryBear/app/core/agent/langchain_agent.py
2025-11-30 18:22:17 +08:00

348 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LangChain Agent 封装
使用 LangChain 1.x 标准方式
- 使用 create_agent 创建 agent graph
- 支持工具调用循环
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False
):
"""初始化 LangChain Agent
Args:
model_name: 模型名称
api_key: API Key
provider: 提供商openai, xinference, gpustack, ollama, dashscope
api_base: API 基础 URL
temperature: 温度参数
max_tokens: 最大 token 数
system_prompt: 系统提示词
tools: 工具列表(可选,框架自动走 ReAct 循环)
streaming: 是否启用流式输出(默认 True
"""
self.model_name = model_name
self.provider = provider
self.system_prompt = system_prompt or "你是一个专业的AI助手"
self.tools = tools or []
self.streaming = streaming
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
"streaming": streaming # 使用参数控制流式
}
)
self.llm = RedBearLLM(model_config, type=ModelType.CHAT)
# 获取底层模型用于真正的流式调用
self._underlying_llm = self.llm._model if hasattr(self.llm, '_model') else self.llm
# 确保底层模型也启用流式
if streaming and hasattr(self._underlying_llm, 'streaming'):
self._underlying_llm.streaming = True
# 使用 create_agent 创建 agent graphLangChain 1.x 标准方式)
# 无论是否有工具,都使用 agent 统一处理
self.agent = create_agent(
model=self.llm,
tools=self.tools if self.tools else None,
system_prompt=self.system_prompt
)
logger.info(
f"LangChain Agent 初始化完成",
extra={
"model": model_name,
"provider": provider,
"has_api_base": bool(api_base),
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
}
)
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None
) -> List[BaseMessage]:
"""准备消息列表
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Returns:
List[BaseMessage]: 消息列表
"""
messages = []
# 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息
if history:
for msg in history:
if msg["role"] == "user":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "assistant":
messages.append(AIMessage(content=msg["content"]))
# 添加当前用户消息
user_content = message
if context:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
async def chat(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""执行对话
Args:
message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果)
Returns:
Dict: 包含 content 和元数据的字典
"""
start_time = time.time()
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备调用 LangChain Agent",
extra={
"has_context": bool(context),
"has_history": bool(history),
"has_tools": bool(self.tools),
"message_count": len(messages)
}
)
# 统一使用 agent.invoke 调用
result = await self.agent.ainvoke({"messages": messages})
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
content = msg.content
break
elapsed_time = time.time() - start_time
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
response = {
"content": content,
"model": self.model_name,
"elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
logger.debug(
f"Agent 调用完成",
extra={
"elapsed_time": elapsed_time,
"content_length": len(response["content"])
}
)
return response
except Exception as e:
logger.error(f"Agent 调用失败", extra={"error": str(e)})
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""执行流式对话
Args:
message: 用户消息
history: 历史消息列表
context: 上下文信息
Yields:
str: 消息内容块
"""
logger.info("=" * 80)
logger.info(f" chat_stream 方法开始执行")
logger.info(f" Message: {message[:100]}")
logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80)
start_time = time.time()
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
else:
if config_id==None:
actual_config_id = os.getenv("config_id")
else:actual_config_id=config_id
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
try:
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
except Exception as e:
logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)})
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备流式调用has_tools={bool(self.tools)}, message_count={len(messages)}"
)
chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2"
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif isinstance(chunk, str):
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
except Exception as e:
logger.error("=" * 80)
logger.error(f"chat_stream 异常: {str(e)}")
logger.error("=" * 80, exc_info=True)
raise
finally:
logger.info("=" * 80)
logger.info(f"chat_stream 方法执行结束")
logger.info("=" * 80)