# File: MemoryBear/api/app/services/agent_server.py
# (file-viewer listing: 130 lines, 3.9 KiB, Python)

from typing import Any, List
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from pydantic import BaseModel
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.tools import tool
from langchain_core.messages import RemoveMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from app.services.api_resquests_server import send_message, model, retrieval
class config(BaseModel):
    """Pydantic model holding a template string, a params dict, model configs and two flags.

    Nothing else in the visible part of this module references this class, so
    the per-field notes below are inferred from the names only.
    """

    # presumably the prompt template text -- TODO confirm against callers
    template_str: str
    # presumably the values substituted into the template -- TODO confirm
    params: dict
    # Optional; an empty list when the caller supplies nothing.
    model_configs: List[dict] = []
    # presumably toggles the history-memory tool -- TODO confirm
    history_memory: bool
    # presumably toggles knowledge-base retrieval -- TODO confirm
    knowledge_base: bool
class RemoryInput(BaseModel):
    """Argument schema passed as ``args_schema`` to the ``remory_sk`` tool."""

    # Text of the user's question; forwarded to send_message by remory_sk.
    question: str
    # Identifier of the end user; remory_sk rejects empty values and '123'.
    end_user_id: str
    # Search mode selector; remory_sk rejects empty values, 'on' and 'off'.
    search_switch: str
class ChatRequest(BaseModel):
    """Pydantic model describing one chat request.

    Its field names cover every attribute that ``tool_memory`` and
    ``tool_Retrieval`` read from their ``req`` argument; presumably this is
    the object passed to them -- verify against the callers.
    """

    # --- fields read by tool_memory ---
    end_user_id: str
    message: str
    search_switch: str

    # --- fields read by tool_Retrieval (together with ``message``) ---
    # Optional; an empty list when the caller supplies nothing.
    kb_ids: List[str] = []
    similarity_threshold: float
    vector_similarity_weight: float
    top_k: int
    hybrid: bool
    token: str
class RetrievalInput(BaseModel):
    """Argument schema passed as ``args_schema`` to the ``retrieval_search`` tool.

    Every field is handed positionally, in this order, to ``retrieval``.
    """

    # Query text to search for.
    message: str
    # Optional; an empty list when the caller supplies nothing.
    kb_ids: List[str] = []
    similarity_threshold: float
    vector_similarity_weight: float
    top_k: int
    hybrid: bool
    # presumably an auth token for the retrieval backend -- TODO confirm
    token: str
async def tool_Retrieval(req):
    """Invoke the ``retrieval_search`` tool with parameters read from *req*.

    Args:
        req: Any object exposing ``message``, ``kb_ids``,
            ``similarity_threshold``, ``vector_similarity_weight``,
            ``top_k``, ``hybrid`` and ``token`` attributes (``ChatRequest``
            and ``RetrievalInput`` both do).

    Returns:
        Whatever the ``retrieval_search`` tool returns.
    """
    payload = {
        "message": req.message,
        "kb_ids": req.kb_ids,
        "similarity_threshold": req.similarity_threshold,
        "vector_similarity_weight": req.vector_similarity_weight,
        "top_k": req.top_k,
        "hybrid": req.hybrid,
        "token": req.token,
    }
    # This is a coroutine, so use the tool's async entry point: the previous
    # synchronous ``.invoke`` ran the whole retrieval call on the event-loop
    # thread and stalled every other task while it waited.
    return await retrieval_search.ainvoke(payload)
async def tool_memory(req):
    """Invoke the ``remory_sk`` tool with parameters read from *req*.

    Args:
        req: Any object exposing ``message``, ``end_user_id`` and
            ``search_switch`` attributes (``ChatRequest`` does).

    Returns:
        Whatever the ``remory_sk`` tool returns.
    """
    payload = {
        # The tool's ``question`` argument is filled from ``req.message``.
        "question": req.message,
        "end_user_id": req.end_user_id,
        "search_switch": req.search_switch,
    }
    # This is a coroutine, so use the tool's async entry point: the previous
    # synchronous ``.invoke`` ran the whole call on the event-loop thread and
    # stalled every other task while it waited.
    return await remory_sk.ainvoke(payload)
# ========== Message-trimming middleware ==========
@before_model
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Trim the history to the first message plus the most recent 10 or 11.

    Args:
        state: Agent state; only ``state["messages"]`` is read.
        runtime: Unused; part of the ``before_model`` hook signature.

    Returns:
        ``None`` when the history holds 10 messages or fewer, or when trimming
        would not drop anything. Otherwise a state update that clears all
        messages and re-adds the first message followed by the most recent
        ones.
    """
    messages = state["messages"]
    if len(messages) <= 10:
        return None

    first_msg = messages[0]
    # Window size follows the parity of the history length: 10 for an even
    # count, 11 for an odd one (presumably to keep the window starting on the
    # same speaker turn -- TODO confirm).
    window = 10 if len(messages) % 2 == 0 else 11
    # Take the window from the messages *after* the first one. Slicing the
    # full list (``messages[-11:]``) with exactly 11 messages picked up
    # ``messages[0]`` again, so the first message was emitted twice.
    later_messages = messages[1:]
    recent_messages = later_messages[-window:]
    if len(recent_messages) == len(later_messages):
        # Nothing would be dropped; leave the state untouched.
        return None

    return {
        "messages": [
            RemoveMessage(id=REMOVE_ALL_MESSAGES),
            first_msg,
            *recent_messages,
        ]
    }
# ----------- History memory -----------
@tool(args_schema=RemoryInput)
def remory_sk(question: str, end_user_id: str, search_switch: str):
    """
    条件调用工具:
    - 仅当 question 是疑问句时调用 send_message
    - 根据 end_user_id 和 search_switch 参数决定是否执行检索
    Args:
    question: 用户的提问内容
    end_user_id: 用户唯一标识符
    search_switch: 搜索开关,控制是否执行检索
    Returns:
    检索结果或空字符串
    """
    # The docstring above is left verbatim on purpose: LangChain's @tool uses
    # it as the tool description shown to the model, so it is runtime text.
    # English gloss: "Conditional tool -- call send_message only when the
    # question is interrogative; end_user_id and search_switch decide whether
    # retrieval runs. Returns the retrieval result or an empty string."
    # (Historical note: the description of ``configurable`` was removed from
    # the docstring to avoid confusion.)

    # Guard 1: a missing user id or the '123' value is rejected.
    user_is_invalid = not end_user_id or end_user_id == '123'
    if user_is_invalid:
        print("警告: 无效的 user_id 参数")
        return ''

    # Guard 2: a missing switch, 'on' or 'off' is rejected.
    # NOTE(review): rejecting 'on'/'off' looks intentional (legacy values?)
    # but cannot be confirmed from this file -- verify with send_message.
    switch_is_invalid = not search_switch or search_switch in ('on', 'off')
    if switch_is_invalid:
        print("警告: 无效的 search_switch 参数")
        return ''

    # The third positional argument is always the literal string '[]'.
    return send_message(end_user_id, question, '[]', search_switch)
# ------------- Retrieval -------------
@tool(args_schema=RetrievalInput)
def retrieval_search(message, kb_ids, similarity_threshold, vector_similarity_weight, top_k, hybrid, token):
    """检索"""
    # The one-word docstring ("Retrieval") is left verbatim: LangChain's @tool
    # uses it as the tool description shown to the model.
    # Thin wrapper: forward every argument positionally, in order, to
    # ``retrieval`` and hand back its result unchanged.
    return retrieval(
        message,
        kb_ids,
        similarity_threshold,
        vector_similarity_weight,
        top_k,
        hybrid,
        token,
    )
async def create_dynamic_agent(model_name: str, model_id: str, PROMPT: str, token: str):
    """Build an agent for the model looked up via ``model(model_id, token)``.

    Args:
        model_name: Accepted for interface compatibility, but its value is
            never used: the name returned by the lookup replaces it.
        model_id: Passed to ``model`` to look up the model settings.
        PROMPT: Used as the agent's system prompt.
        token: Passed to ``model`` alongside ``model_id``.

    Returns:
        The object returned by ``create_agent``, wired with the ``remory_sk``
        and ``retrieval_search`` tools, the ``trim_messages`` middleware and a
        fresh ``InMemorySaver`` checkpointer.
    """
    # The lookup yields (model name, API key, API base URL), in that order.
    resolved_name, api_key, api_base = await model(model_id, token)

    chat_model = ChatOpenAI(
        model=resolved_name,
        base_url=api_base,
        temperature=0.2,
        api_key=api_key,
    )
    # A new in-memory checkpointer is created for every agent built here.
    checkpointer = InMemorySaver()

    return create_agent(
        chat_model,
        tools=[remory_sk, retrieval_search],
        middleware=[trim_messages],
        checkpointer=checkpointer,
        system_prompt=PROMPT,
    )