[MODIFY] Code optimization
This commit is contained in:
@@ -1,35 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.agent.agent_chat import Agent_chat
|
||||
from app.core.logging_config import get_business_logger
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from app.dependencies import workspace_access_guard
|
||||
from app.services.agent_server import config,ChatRequest
|
||||
router = APIRouter(prefix="/Test", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
class CombinedRequest(BaseModel):
|
||||
config_base: config
|
||||
agent_config: ChatRequest
|
||||
|
||||
@router.post("", summary="uuid")
|
||||
async def agent_chat(
|
||||
config_base: CombinedRequest
|
||||
):
|
||||
chat_config=config_base.agent_config
|
||||
chat_base=config_base.config_base
|
||||
request = ChatRequest(
|
||||
end_user_id=chat_config.end_user_id,
|
||||
message=chat_config.message,
|
||||
search_switch=chat_config.search_switch,
|
||||
kb_ids=chat_config.kb_ids,
|
||||
similarity_threshold=chat_config.similarity_threshold,
|
||||
vector_similarity_weight=chat_config.vector_similarity_weight,
|
||||
top_k=chat_config.top_k,
|
||||
hybrid=chat_config.hybrid,
|
||||
token=chat_config.token
|
||||
)
|
||||
|
||||
chat_result=await Agent_chat(chat_base).chat(request)
|
||||
|
||||
return chat_result
|
||||
@@ -1,109 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.api_resquests_server import messages_type, write_messages
|
||||
from app.services.agent_server import ChatRequest, tool_memory, create_dynamic_agent, tool_Retrieval
|
||||
|
||||
logger = get_business_logger()
|
||||
class Agent_chat:
|
||||
def __init__(self,config_data: dict):
|
||||
self.prompt_message = render_prompt_message(
|
||||
config_data.template_str,
|
||||
PromptMessageRole.USER,
|
||||
config_data.params
|
||||
)
|
||||
self.prompt = self.prompt_message.get_text_content()
|
||||
self.model_configs = config_data.model_configs
|
||||
self.history_memory = config_data.history_memory
|
||||
self.knowledge_base = config_data.knowledge_base
|
||||
logger.info(f"渲染结果:{self.prompt_message.get_text_content()}" )
|
||||
|
||||
async def run_agent(self,agent, end_user_id:str, user_prompt:str, model_name:str):
|
||||
response = agent.invoke(
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt
|
||||
}
|
||||
]
|
||||
},
|
||||
{"configurable": {"thread_id": f'{model_name}_{end_user_id}'}},
|
||||
)
|
||||
outputs = []
|
||||
for msg in response["messages"]:
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
outputs.append({
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"name": t["name"], "arguments": t["args"]}
|
||||
for t in msg.tool_calls
|
||||
]
|
||||
})
|
||||
elif hasattr(msg, "content") and msg.content:
|
||||
outputs.append({
|
||||
"role": msg.__class__.__name__.lower().replace("message", ""),
|
||||
"content": msg.content
|
||||
})
|
||||
ai_messages=[msg['content'] for msg in outputs if msg["role"] == "ai"]
|
||||
return {"model_name": model_name, "end_user_id": end_user_id, "response": ai_messages}
|
||||
|
||||
async def chat(self,req: ChatRequest) -> Dict[str, Any]:
|
||||
|
||||
end_user_id = req.end_user_id # 用 user_id 作为对话线程标识
|
||||
start=time.time()
|
||||
user_prompt = req.message
|
||||
|
||||
'''判断是都写入redis数据库'''
|
||||
messags_type = await messages_type(req.message,end_user_id)
|
||||
messags_type=messags_type['data']
|
||||
if messags_type=='question':
|
||||
writer_result=await write_messages(f'{end_user_id}', req.message)
|
||||
logger.info(f'判断类型写入耗时:{time.time() - start},{writer_result}')
|
||||
|
||||
|
||||
|
||||
'''history_memory'''
|
||||
|
||||
if self.history_memory==True:
|
||||
tool_result =await tool_memory(req)
|
||||
if tool_result!='' :tool_result=tool_result['data']
|
||||
if tool_result!='' :self.prompt=self.prompt+f''',历史消息:{tool_result},结合历史消息'''
|
||||
logger.info(f"记忆科学消耗时间:{time.time()-start},工具调用结果:{tool_result}")
|
||||
|
||||
'''baidu'''
|
||||
|
||||
|
||||
'''knowledge_base'''
|
||||
if self.knowledge_base == True:
|
||||
retrieval_result=await tool_Retrieval(req)
|
||||
retrieval_knowledge = [i['page_content'] for i in retrieval_result['data']]
|
||||
retrieval_knowledge=','.join(retrieval_knowledge)
|
||||
logger.info(f"检索消耗时间:{time.time()-start},{retrieval_knowledge}")
|
||||
if retrieval_knowledge!='' :self.prompt=self.prompt+f",知识库检索内容:{retrieval_knowledge},结合检索结果"
|
||||
self.prompt=self.prompt+f'给出最合适的答案,确保答案的完整性,只保留用户的问题的回答,不额外输出提示语'
|
||||
logger.info(f"用户输入:{user_prompt}")
|
||||
logger.info(f"系统prompt:{self.prompt}")
|
||||
|
||||
AGENTS = {
|
||||
cfg["name"]: await create_dynamic_agent(cfg["name"], cfg["moder_id"], self.prompt, req.token)
|
||||
for cfg in self.model_configs
|
||||
}
|
||||
tasks=[
|
||||
self.run_agent(agent, end_user_id, user_prompt, model_name)
|
||||
for model_name, agent in AGENTS.items()
|
||||
]
|
||||
# 并行运行
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
result=[]
|
||||
|
||||
for i in results:
|
||||
result.append(i)
|
||||
chat_result=(f"最终耗时:{time.time()-start},{result}")
|
||||
return chat_result
|
||||
@@ -15,6 +15,8 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, Base
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from app.core.memory.agent.mcp_server.services import session_service
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -89,7 +91,7 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LangChain Agent 初始化完成",
|
||||
"LangChain Agent 初始化完成",
|
||||
extra={
|
||||
"model": model_name,
|
||||
"provider": provider,
|
||||
@@ -139,6 +141,42 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
end_user_end=f"Term_{end_user_end}"
|
||||
print(messages)
|
||||
print(aimessages)
|
||||
session_id = store.save_session(
|
||||
userid=end_user_end,
|
||||
messages=messages,
|
||||
apply_id=end_user_end,
|
||||
group_id=end_user_end,
|
||||
aimessages=aimessages
|
||||
)
|
||||
store.delete_duplicate_sessions()
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
return session_id
|
||||
async def term_memory_redis_read(self,end_user_end):
|
||||
end_user_end = f"Term_{end_user_end}"
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
return messagss_list
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -149,6 +187,7 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -160,29 +199,29 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
start_time = time.time()
|
||||
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
actual_config_id = config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
if config_id==None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:actual_config_id=config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
|
||||
|
||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory=';'.join(history_term_memory)
|
||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
|
||||
logger.debug(
|
||||
f"准备调用 LangChain Agent",
|
||||
"准备调用 LangChain Agent",
|
||||
extra={
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
@@ -203,15 +242,9 @@ class LangChainAgent:
|
||||
break
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, user_rag_memory_id)
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
|
||||
if memory_flag:
|
||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
||||
await self.term_memory_save(message_chat,end_user_id,content)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -224,7 +257,7 @@ class LangChainAgent:
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Agent 调用完成",
|
||||
"Agent 调用完成",
|
||||
extra={
|
||||
"elapsed_time": elapsed_time,
|
||||
"content_length": len(response["content"])
|
||||
@@ -234,7 +267,7 @@ class LangChainAgent:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 调用失败", extra={"error": str(e)})
|
||||
logger.error("Agent 调用失败", extra={"error": str(e)})
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
@@ -246,7 +279,7 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
|
||||
memory_flag: Optional[bool] = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -259,28 +292,27 @@ class LangChainAgent:
|
||||
str: 消息内容块
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(f" chat_stream 方法开始执行")
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
logger.info(f" Message: {message[:100]}")
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
start_time = time.time()
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
message_chat = message
|
||||
if config_id == None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:
|
||||
if config_id==None:
|
||||
actual_config_id = os.getenv("config_id")
|
||||
else:actual_config_id=config_id
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
write_id = write_message_task.delay(actual_end_user_id, message, actual_config_id,storage_type,user_rag_memory_id)
|
||||
actual_config_id = config_id
|
||||
|
||||
try:
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 记忆用户输入出错", extra={"error": str(e)})
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -294,7 +326,7 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
|
||||
full_content=''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
@@ -307,6 +339,7 @@ class LangChainAgent:
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
@@ -316,6 +349,7 @@ class LangChainAgent:
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
@@ -329,6 +363,9 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
if memory_flag:
|
||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
@@ -341,7 +378,7 @@ class LangChainAgent:
|
||||
raise
|
||||
finally:
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"chat_stream 方法执行结束")
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user