[MODIFY] Code optimization
This commit is contained in:
@@ -4,7 +4,7 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
@@ -16,6 +16,8 @@ from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -88,7 +90,7 @@ class SharedChatService:
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
f"会话不存在,将创建新会话",
|
||||
"会话不存在,将创建新会话",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
@@ -102,7 +104,7 @@ class SharedChatService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"为分享链接创建新会话",
|
||||
"为分享链接创建新会话",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"share_token": share_token,
|
||||
@@ -121,17 +123,24 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
actual_config_id = None
|
||||
config_id=actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -199,10 +208,11 @@ class SharedChatService:
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
|
||||
memory_flag=False
|
||||
if memory==True:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag=True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
@@ -234,6 +244,7 @@ class SharedChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
@@ -254,7 +265,11 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -280,6 +295,7 @@ class SharedChatService:
|
||||
# )
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -301,7 +317,9 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
@@ -312,6 +330,9 @@ class SharedChatService:
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -381,9 +402,11 @@ class SharedChatService:
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag=False
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
@@ -440,7 +463,11 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
@@ -464,13 +491,14 @@ class SharedChatService:
|
||||
"usage": {}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
f"流式聊天完成",
|
||||
"流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -539,13 +567,19 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""多 Agent 聊天(非流式)"""
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -609,6 +643,8 @@ class SharedChatService:
|
||||
"sub_results": result.get("sub_results")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -630,11 +666,16 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -741,13 +782,14 @@ class SharedChatService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多 Agent 流式聊天完成",
|
||||
"多 Agent 流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
|
||||
Reference in New Issue
Block a user