fix(db): fix database connection leak
This commit is contained in:
@@ -1,28 +1,29 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.rag.llm.cv_model import QWenCV
|
from app.core.rag.llm.cv_model import QWenCV
|
||||||
from app.core.response_utils import fail, success
|
from app.core.response_utils import fail, success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||||
from app.models import ModelApiKey
|
from app.models import ModelApiKey
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.repositories import knowledge_repository
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
|
||||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import task_service, workspace_service
|
from app.services import task_service, workspace_service
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
from dotenv import load_dotenv
|
|
||||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.responses import StreamingResponse
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -37,7 +38,7 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/health/status", response_model=ApiResponse)
|
@router.get("/health/status", response_model=ApiResponse)
|
||||||
async def get_health_status(
|
async def get_health_status(
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get latest health status written by Celery periodic task
|
Get latest health status written by Celery periodic task
|
||||||
@@ -55,8 +56,9 @@ async def get_health_status(
|
|||||||
|
|
||||||
@router.get("/download_log")
|
@router.get("/download_log")
|
||||||
async def download_log(
|
async def download_log(
|
||||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||||
current_user: User = Depends(get_current_user)
|
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Download or stream agent service log file
|
Download or stream agent service log file
|
||||||
@@ -119,10 +121,10 @@ async def download_log(
|
|||||||
@router.post("/writer_service", response_model=ApiResponse)
|
@router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server(
|
async def write_server(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write service endpoint - processes write operations synchronously
|
Write service endpoint - processes write operations synchronously
|
||||||
@@ -161,13 +163,15 @@ async def write_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
else:
|
||||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
api_logger.warning(
|
||||||
|
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
else:
|
else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
|
|
||||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
api_logger.info(
|
||||||
|
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
@@ -195,10 +199,10 @@ async def write_server(
|
|||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
Async write service endpoint - enqueues write processing to Celery
|
||||||
@@ -216,7 +220,8 @@ async def write_server_async(
|
|||||||
|
|
||||||
config_id = user_input.config_id
|
config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
api_logger.info(
|
||||||
|
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
storage_type = workspace_service.get_workspace_storage_type(
|
||||||
@@ -254,9 +259,9 @@ async def write_server_async(
|
|||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def read_server(
|
async def read_server(
|
||||||
user_input: UserInput,
|
user_input: UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read service endpoint - processes read operations synchronously
|
Read service endpoint - processes read operations synchronously
|
||||||
@@ -292,7 +297,8 @@ async def read_server(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
api_logger.info(
|
||||||
|
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.read_memory(
|
result = await memory_agent_service.read_memory(
|
||||||
user_input.end_user_id,
|
user_input.end_user_id,
|
||||||
@@ -306,7 +312,8 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
if str(user_input.search_switch) == "2":
|
||||||
retrieve_info = result['answer']
|
retrieve_info = result['answer']
|
||||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||||
|
user_input.end_user_id)
|
||||||
query = user_input.message
|
query = user_input.message
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
@@ -319,7 +326,7 @@ async def read_server(
|
|||||||
db=db
|
db=db
|
||||||
)
|
)
|
||||||
if "信息不足,无法回答" in result['answer']:
|
if "信息不足,无法回答" in result['answer']:
|
||||||
result['answer']=retrieve_info
|
result['answer'] = retrieve_info
|
||||||
return success(data=result, msg="回复对话消息成功")
|
return success(data=result, msg="回复对话消息成功")
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
@@ -335,9 +342,10 @@ async def read_server(
|
|||||||
@router.post("/file", response_model=ApiResponse)
|
@router.post("/file", response_model=ApiResponse)
|
||||||
async def file_update(
|
async def file_update(
|
||||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||||
model_id:str = Form(..., description="模型ID"),
|
model_id: str = Form(..., description="模型ID"),
|
||||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
文件上传接口 - 支持图片识别
|
文件上传接口 - 支持图片识别
|
||||||
@@ -350,9 +358,6 @@ async def file_update(
|
|||||||
Returns:
|
Returns:
|
||||||
文件处理结果
|
文件处理结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_gen = get_db() # get_db 通常是一个生成器
|
|
||||||
db = next(db_gen)
|
|
||||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
apiConfig: ModelApiKey = config.api_keys[0]
|
apiConfig: ModelApiKey = config.api_keys[0]
|
||||||
@@ -430,8 +435,8 @@ async def read_server_async(
|
|||||||
|
|
||||||
@router.get("/read_result/", response_model=ApiResponse)
|
@router.get("/read_result/", response_model=ApiResponse)
|
||||||
async def get_read_task_result(
|
async def get_read_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async read task
|
Get the status and result of an async read task
|
||||||
@@ -507,8 +512,8 @@ async def get_read_task_result(
|
|||||||
|
|
||||||
@router.get("/write_result/", response_model=ApiResponse)
|
@router.get("/write_result/", response_model=ApiResponse)
|
||||||
async def get_write_task_result(
|
async def get_write_task_result(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Get the status and result of an async write task
|
Get the status and result of an async write task
|
||||||
@@ -584,9 +589,9 @@ async def get_write_task_result(
|
|||||||
|
|
||||||
@router.post("/status_type", response_model=ApiResponse)
|
@router.post("/status_type", response_model=ApiResponse)
|
||||||
async def status_type(
|
async def status_type(
|
||||||
user_input: Write_UserInput,
|
user_input: Write_UserInput,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -629,9 +634,10 @@ async def status_type(
|
|||||||
|
|
||||||
@router.get("/stats/types", response_model=ApiResponse)
|
@router.get("/stats/types", response_model=ApiResponse)
|
||||||
async def get_knowledge_type_stats_api(
|
async def get_knowledge_type_stats_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||||
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
|
|||||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
api_logger.info(
|
||||||
|
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
# 获取数据库会话
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 调用service层函数
|
# 调用service层函数
|
||||||
result = await memory_agent_service.get_knowledge_type_stats(
|
result = await memory_agent_service.get_knowledge_type_stats(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
|
|||||||
|
|
||||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||||
async def get_interest_distribution_by_user_api(
|
async def get_interest_distribution_by_user_api(
|
||||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签
|
获取指定用户的兴趣分布标签
|
||||||
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||||
async def get_user_profile_api(
|
async def get_user_profile_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -782,9 +783,9 @@ async def get_user_profile_api(
|
|||||||
|
|
||||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||||
async def get_end_user_connected_config(
|
async def get_end_user_connected_config(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取终端用户关联的记忆配置
|
获取终端用户关联的记忆配置
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
structured = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务
|
# 使用优化的LLM服务
|
||||||
response_content = await problem_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
response_content = await problem_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=ProblemExtensionResponse,
|
system_prompt=system_prompt,
|
||||||
fallback_value=[]
|
response_model=ProblemExtensionResponse,
|
||||||
)
|
fallback_value=[]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
|
|||||||
@@ -6,31 +6,26 @@ import os
|
|||||||
# ===== 第三方库 =====
|
# ===== 第三方库 =====
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
|
|
||||||
from app.schemas import model_schema
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
from app.services.model_service import ModelConfigService
|
|
||||||
|
|
||||||
from app.core.memory.agent.services.search_service import SearchService
|
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
|
||||||
COUNTState,
|
|
||||||
ReadState,
|
|
||||||
deduplicate_entries,
|
|
||||||
merge_to_key_value_pairs,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||||
create_hybrid_retrieval_tool_sync,
|
create_hybrid_retrieval_tool_sync,
|
||||||
create_time_retrieval_tool,
|
create_time_retrieval_tool,
|
||||||
extract_tool_message_content,
|
extract_tool_message_content,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.services.search_service import SearchService
|
||||||
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
|
ReadState,
|
||||||
|
deduplicate_entries,
|
||||||
|
merge_to_key_value_pairs,
|
||||||
|
)
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.schemas import model_schema
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db = next(get_db())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
@@ -50,10 +45,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
text_parts.append(item)
|
text_parts.append(item)
|
||||||
|
|
||||||
|
|
||||||
return '\n'.join(text_parts).strip()
|
return '\n'.join(text_parts).strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
模型信息
|
模型信息
|
||||||
'''
|
'''
|
||||||
|
|
||||||
problem_extension=state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type=state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
end_user_id=state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
original=state.get('data', '')
|
original = state.get('data', '')
|
||||||
problem_list=[]
|
problem_list = []
|
||||||
for key,values in problem_extension.items():
|
for key, values in problem_extension.items():
|
||||||
for data in values:
|
for data in values:
|
||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# 创建异步任务处理单个问题
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
send_verify = []
|
send_verify = []
|
||||||
for i, j in zip(keys, val, strict=False):
|
for i, j in zip(keys, val, strict=False):
|
||||||
if j!=['']:
|
if j != ['']:
|
||||||
send_verify.append({
|
send_verify.append({
|
||||||
"Query_small": i,
|
"Query_small": i,
|
||||||
"Answer_Small": j
|
"Answer_Small": j
|
||||||
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve':dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取end_user_id
|
# 从state中获取end_user_id
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
return await llm_infomation(state)
|
return await llm_infomation(state)
|
||||||
|
|
||||||
llm_config = await get_llm_info()
|
llm_config = await get_llm_info()
|
||||||
api_key_obj = llm_config.api_keys[0]
|
api_key_obj = llm_config.api_keys[0]
|
||||||
api_key = api_key_obj.api_key
|
api_key = api_key_obj.api_key
|
||||||
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
async with SEMAPHORE: # 限制并发
|
async with SEMAPHORE: # 限制并发
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||||
|
question)
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||||
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
# json.dump(dup_databases, f, indent=4)
|
# json.dump(dup_databases, f, indent=4)
|
||||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||||
return {'retrieve': dup_databases}
|
return {'retrieve': dup_databases}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
|
|||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
db_session = next(get_db())
|
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""总结节点服务类"""
|
"""总结节点服务类"""
|
||||||
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
@@ -51,10 +51,12 @@ async def rag_config(state):
|
|||||||
"reranker_top_k": 10
|
"reranker_top_k": 10
|
||||||
}
|
}
|
||||||
return kb_config
|
return kb_config
|
||||||
async def rag_knowledge(state,question):
|
|
||||||
|
|
||||||
|
async def rag_knowledge(state, question):
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||||
try:
|
try:
|
||||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||||
@@ -62,20 +64,23 @@ async def rag_knowledge(state,question):
|
|||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
raw_results = clean_content
|
raw_results = clean_content
|
||||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||||
except Exception :
|
except Exception:
|
||||||
retrieval_knowledge=[]
|
retrieval_knowledge = []
|
||||||
clean_content = ''
|
clean_content = ''
|
||||||
raw_results = ''
|
raw_results = ''
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
|
||||||
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||||
|
search_mode) -> str:
|
||||||
"""
|
"""
|
||||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||||
"""
|
"""
|
||||||
@@ -99,13 +104,14 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务进行结构化输出
|
# 使用优化的LLM服务进行结构化输出
|
||||||
structured = await summary_service.call_llm_structured(
|
with get_db_context() as db_session:
|
||||||
state=state,
|
structured = await summary_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=response_model,
|
system_prompt=system_prompt,
|
||||||
fallback_value=None
|
response_model=response_model,
|
||||||
)
|
fallback_value=None
|
||||||
|
)
|
||||||
# 验证结构化响应
|
# 验证结构化响应
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
@@ -157,7 +163,8 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
logger.error(f"Fallback也失败: {fallback_error}")
|
logger.error(f"Fallback也失败: {fallback_error}")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|
||||||
|
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
|||||||
)
|
)
|
||||||
await SessionService(store).cleanup_duplicates()
|
await SessionService(store).cleanup_duplicates()
|
||||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|
||||||
storage_type=state.get("storage_type",'')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||||
data=state.get("data", '')
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
|
data = state.get("data", '')
|
||||||
input_summary = {
|
input_summary = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
retrieve={
|
retrieve = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id,
|
"user_rag_memory_id": user_rag_memory_id,
|
||||||
"_intermediate": {
|
"_intermediate": {
|
||||||
"type": "retrieval_summary",
|
"type": "retrieval_summary",
|
||||||
"title":"快速检索",
|
"title": "快速检索",
|
||||||
"summary": aimessages,
|
"summary": aimessages,
|
||||||
"query": data,
|
"query": data,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return input_summary,retrieve
|
return input_summary, retrieve
|
||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
storage_type=state.get("storage_type",'')
|
storage_type = state.get("storage_type", '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
data=state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id=state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
history = await summary_history( state)
|
history = await summary_history(state)
|
||||||
search_params = {
|
search_params = {
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type!="rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||||
|
memory_config=memory_config)
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||||
retrieve_info, question, raw_results = "", data, []
|
retrieve_info, question, raw_results = "", data, []
|
||||||
try:
|
try:
|
||||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||||
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||||
summary = summary_result[0]
|
summary = summary_result[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||||
summary= {
|
summary = {
|
||||||
"status": "fail",
|
"status": "fail",
|
||||||
"summary_result": "信息不足,无法回答",
|
"summary_result": "信息不足,无法回答",
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
except Exception:
|
except Exception:
|
||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|
||||||
retrieve=state.get("retrieve", '')
|
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||||
history = await summary_history( state)
|
retrieve = state.get("retrieve", '')
|
||||||
|
history = await summary_history(state)
|
||||||
import json
|
import json
|
||||||
with open("检索.json","w",encoding='utf-8') as f:
|
with open("检索.json", "w", encoding='utf-8') as f:
|
||||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||||
retrieve=retrieve.get("Expansion_issue", [])
|
retrieve = retrieve.get("Expansion_issue", [])
|
||||||
start=time.time()
|
start = time.time()
|
||||||
retrieve_info_str=[]
|
retrieve_info_str = []
|
||||||
for data in retrieve:
|
for data in retrieve:
|
||||||
if data=='':
|
if data == '':
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
else:
|
else:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='Answer_Small':
|
if key == 'Answer_Small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str.append(i)
|
retrieve_info_str.append(i)
|
||||||
retrieve_info_str=list(set(retrieve_info_str))
|
retrieve_info_str = list(set(retrieve_info_str))
|
||||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -290,29 +302,29 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
|||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState)-> ReadState:
|
async def Summary(state: ReadState) -> ReadState:
|
||||||
start=time.time()
|
start = time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify=state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
verify_expansion_issue=verify.get("verified_data", '')
|
verify_expansion_issue = verify.get("verified_data", '')
|
||||||
retrieve_info_str=''
|
retrieve_info_str = ''
|
||||||
for data in verify_expansion_issue:
|
for data in verify_expansion_issue:
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key=='answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str+=i+'\n'
|
retrieve_info_str += i + '\n'
|
||||||
history=await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages=await summary_llm(state,history,data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
|
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
|
|||||||
# 修复协程调用 - 先await,然后访问返回值
|
# 修复协程调用 - 先await,然后访问返回值
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary":summary}
|
return {"summary": summary}
|
||||||
async def Summary_fails(state: ReadState)-> ReadState:
|
|
||||||
storage_type=state.get("storage_type", '')
|
|
||||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
async def Summary_fails(state: ReadState) -> ReadState:
|
||||||
|
storage_type = state.get("storage_type", '')
|
||||||
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify = state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
|||||||
"history": history,
|
"history": history,
|
||||||
"retrieve_info": retrieve_info_str
|
"retrieve_info": retrieve_info_str
|
||||||
}
|
}
|
||||||
aimessages = await summary_llm(state, history, data,
|
aimessages = await summary_llm(state, history, data,
|
||||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||||
result= {
|
result = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"summary_result": aimessages,
|
"summary_result": aimessages,
|
||||||
"storage_type": storage_type,
|
"storage_type": storage_type,
|
||||||
"user_rag_memory_id": user_rag_memory_id
|
"user_rag_memory_id": user_rag_memory_id
|
||||||
}
|
}
|
||||||
return {"summary":result}
|
return {"summary": result}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.db import get_db
|
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||||
|
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||||
from app.core.memory.agent.utils.llm_tools import (
|
from app.core.memory.agent.utils.llm_tools import (
|
||||||
PROJECT_ROOT_,
|
PROJECT_ROOT_,
|
||||||
ReadState,
|
ReadState,
|
||||||
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
from app.db import get_db_context
|
||||||
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
db_session = next(get_db())
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""验证节点服务类"""
|
"""验证节点服务类"""
|
||||||
|
|
||||||
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# 创建全局服务实例
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""处理验证结果并生成输出格式"""
|
"""处理验证结果并生成输出格式"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Verify_result
|
return Verify_result
|
||||||
|
|
||||||
|
|
||||||
async def Verify(state: ReadState):
|
async def Verify(state: ReadState):
|
||||||
logger.info("=== Verify 节点开始执行 ===")
|
logger.info("=== Verify 节点开始执行 ===")
|
||||||
try:
|
try:
|
||||||
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
|
|||||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||||
|
|
||||||
retrieve = state.get("retrieve", {})
|
retrieve = state.get("retrieve", {})
|
||||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
logger.info(
|
||||||
|
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||||
|
|
||||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||||
@@ -100,23 +106,24 @@ async def Verify(state: ReadState):
|
|||||||
try:
|
try:
|
||||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||||
import asyncio
|
|
||||||
structured = await asyncio.wait_for(
|
with get_db_context() as db_session:
|
||||||
verification_service.call_llm_structured(
|
structured = await asyncio.wait_for(
|
||||||
state=state,
|
verification_service.call_llm_structured(
|
||||||
db_session=db_session,
|
state=state,
|
||||||
system_prompt=system_prompt,
|
db_session=db_session,
|
||||||
response_model=VerificationResult,
|
system_prompt=system_prompt,
|
||||||
fallback_value={
|
response_model=VerificationResult,
|
||||||
"query": content,
|
fallback_value={
|
||||||
"history": history if isinstance(history, list) else [],
|
"query": content,
|
||||||
"expansion_issue": [],
|
"history": history if isinstance(history, list) else [],
|
||||||
"split_result": "failed",
|
"expansion_issue": [],
|
||||||
"reason": "验证失败或超时"
|
"split_result": "failed",
|
||||||
}
|
"reason": "验证失败或超时"
|
||||||
),
|
}
|
||||||
timeout=150.0 # 150秒超时
|
),
|
||||||
)
|
timeout=150.0 # 150秒超时
|
||||||
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""创建并返回 LangGraph 工作流"""
|
"""创建并返回 LangGraph 工作流"""
|
||||||
@@ -62,7 +60,6 @@ async def make_read_graph():
|
|||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
|
|
||||||
'''-----'''
|
'''-----'''
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
@@ -76,6 +73,7 @@ async def make_read_graph():
|
|||||||
finally:
|
finally:
|
||||||
print("工作流创建完成")
|
print("工作流创建完成")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""主函数 - 运行工作流"""
|
"""主函数 - 运行工作流"""
|
||||||
message = "昨天有什么好看的电影"
|
message = "昨天有什么好看的电影"
|
||||||
@@ -92,13 +90,15 @@ async def main():
|
|||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
start=time.time()
|
start = time.time()
|
||||||
try:
|
try:
|
||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
"end_user_id": end_user_id
|
||||||
|
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||||
|
"memory_config": memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
summary = ''
|
summary = ''
|
||||||
@@ -141,7 +141,6 @@ async def main():
|
|||||||
if verify_n and verify_n != [] and verify_n != {}:
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
_intermediate_outputs.append(verify_n)
|
_intermediate_outputs.append(verify_n)
|
||||||
|
|
||||||
|
|
||||||
# Summary 节点
|
# Summary 节点
|
||||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||||
if summary_n and summary_n != [] and summary_n != {}:
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
@@ -165,13 +164,16 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
end=time.time()
|
end = time.time()
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
print(f"总耗时: {end-start}s")
|
print(f"总耗时: {end - start}s")
|
||||||
print(100*'y')
|
print(100 * 'y')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Dict, Optional
|
|
||||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
|
|
||||||
class LLMClientPool:
|
|
||||||
"""LLM客户端连接池"""
|
|
||||||
|
|
||||||
def __init__(self, max_size: int = 5):
|
|
||||||
self.max_size = max_size
|
|
||||||
self.pools: Dict[str, asyncio.Queue] = {}
|
|
||||||
self.active_clients: Dict[str, int] = {}
|
|
||||||
|
|
||||||
async def get_client(self, llm_model_id: str):
|
|
||||||
"""获取LLM客户端"""
|
|
||||||
if llm_model_id not in self.pools:
|
|
||||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
|
||||||
self.active_clients[llm_model_id] = 0
|
|
||||||
|
|
||||||
pool = self.pools[llm_model_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 尝试从池中获取客户端
|
|
||||||
client = pool.get_nowait()
|
|
||||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
# 池为空,创建新客户端
|
|
||||||
if self.active_clients[llm_model_id] < self.max_size:
|
|
||||||
db_session = next(get_db())
|
|
||||||
client = get_llm_client_fast(llm_model_id, db_session)
|
|
||||||
self.active_clients[llm_model_id] += 1
|
|
||||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
|
||||||
return client
|
|
||||||
else:
|
|
||||||
# 等待可用客户端
|
|
||||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
|
||||||
return await pool.get()
|
|
||||||
|
|
||||||
async def return_client(self, llm_model_id: str, client):
|
|
||||||
"""归还LLM客户端到池中"""
|
|
||||||
if llm_model_id in self.pools:
|
|
||||||
try:
|
|
||||||
self.pools[llm_model_id].put_nowait(client)
|
|
||||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
|
||||||
except asyncio.QueueFull:
|
|
||||||
# 池已满,丢弃客户端
|
|
||||||
self.active_clients[llm_model_id] -= 1
|
|
||||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
|
||||||
|
|
||||||
# 全局客户端池
|
|
||||||
llm_client_pool = LLMClientPool()
|
|
||||||
@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
|||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.models import AppRelease
|
from app.models import AppRelease
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
|
||||||
"""准备 Agent(公共逻辑)
|
"""准备 Agent(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
|
|||||||
if not agent_id:
|
if not agent_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||||
|
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
release = db.query(AppRelease).filter(
|
release = db.query(AppRelease).filter(
|
||||||
AppRelease.id == agent_id
|
AppRelease.id == agent_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
draft_service = AgentRunService(db)
|
|
||||||
|
|
||||||
return draft_service, release, message
|
|
||||||
|
return release, message
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""非流式执行
|
"""非流式执行
|
||||||
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
状态更新字典
|
状态更新字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||||
|
with get_db_context() as db:
|
||||||
|
draft_service = AgentRunService(db)
|
||||||
|
|
||||||
# 执行 Agent(非流式)
|
# 执行 Agent(非流式)
|
||||||
result = await draft_service.run(
|
result = await draft_service.run(
|
||||||
agent_config=release.config,
|
agent_config=release.config,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
user_id=state.get("user_id"),
|
user_id=state.get("user_id"),
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
)
|
)
|
||||||
|
|
||||||
response = result.get("response", "")
|
response = result.get("response", "")
|
||||||
|
|
||||||
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
|
|||||||
Yields:
|
Yields:
|
||||||
流式事件字典
|
流式事件字典
|
||||||
"""
|
"""
|
||||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
release, message = self._prepare_agent(variable_pool)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
with get_db_context() as db:
|
||||||
|
draft_service = AgentRunService(db)
|
||||||
# 执行 Agent(流式)
|
# 执行 Agent(流式)
|
||||||
async for chunk in draft_service.run_stream(
|
async for chunk in draft_service.run_stream(
|
||||||
agent_config=release.config,
|
agent_config=release.config,
|
||||||
model_config=None,
|
model_config=None,
|
||||||
message=message,
|
message=message,
|
||||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||||
user_id=state.get("user_id"),
|
user_id=state.get("user_id"),
|
||||||
variables=variable_pool.get_all_conversation_vars()
|
variables=variable_pool.get_all_conversation_vars()
|
||||||
):
|
):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
content = chunk.get("content", "")
|
content = chunk.get("content", "")
|
||||||
full_response += content
|
full_response += content
|
||||||
|
|
||||||
# 流式返回每个 chunk
|
# 流式返回每个 chunk
|
||||||
yield {
|
yield {
|
||||||
"type": "chunk",
|
"type": "chunk",
|
||||||
"node_id": self.node_id,
|
"node_id": self.node_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
"full_content": full_response,
|
"full_content": full_response,
|
||||||
"meta_data": chunk.get("meta_data", {})
|
"meta_data": chunk.get("meta_data", {})
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput
|
||||||
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
|
|||||||
"""
|
"""
|
||||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||||
try:
|
try:
|
||||||
from app.db import get_db
|
with get_db_context() as db:
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
memory_content = asyncio.run(
|
memory_content = asyncio.run(
|
||||||
MemoryAgentService().read_memory(
|
MemoryAgentService().read_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
|
|||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
if memory_content:
|
if memory_content:
|
||||||
memory_content = memory_content['answer']
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
|
|||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
|
|||||||
reorder_output_results,
|
reorder_output_results,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_agent_schema import Write_UserInput
|
from app.schemas.memory_agent_schema import Write_UserInput
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
@@ -69,7 +66,8 @@ class MemoryAgentService:
|
|||||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=True,
|
||||||
duration=duration, details={"message_length": len(message)})
|
duration=duration, details={"message_length": len(message)})
|
||||||
return context
|
return context
|
||||||
else:
|
else:
|
||||||
@@ -88,8 +86,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(f"写入失败: {messages}")
|
raise ValueError(f"写入失败: {messages}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||||
"""Extract tool call information from event"""
|
"""Extract tool call information from event"""
|
||||||
last_message = event["messages"][-1]
|
last_message = event["messages"][-1]
|
||||||
@@ -271,7 +267,8 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
|
||||||
|
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -300,7 +297,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -331,7 +329,8 @@ class MemoryAgentService:
|
|||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
@@ -351,9 +350,9 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
print(langchain_messages)
|
print(langchain_messages)
|
||||||
print(100*'-')
|
print(100 * '-')
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
@@ -375,29 +374,28 @@ class MemoryAgentService:
|
|||||||
contents = massages.get('write_result')
|
contents = massages.get('write_result')
|
||||||
# Convert messages back to string for logging
|
# Convert messages back to string for logging
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
|
||||||
|
contents)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
|
success=False, duration=duration, error=error_msg)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
search_switch: str,
|
search_switch: str,
|
||||||
config_id: Optional[uuid.UUID]|int,
|
config_id: Optional[uuid.UUID] | int,
|
||||||
db: Session,
|
db: Session,
|
||||||
storage_type: str,
|
storage_type: str,
|
||||||
user_rag_memory_id: str) -> Dict:
|
user_rag_memory_id: str) -> Dict:
|
||||||
"""
|
"""
|
||||||
Process read operation with config_id
|
Process read operation with config_id
|
||||||
|
|
||||||
@@ -425,7 +423,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ori_message= message
|
ori_message = message
|
||||||
|
|
||||||
# Resolve config_id and workspace_id
|
# Resolve config_id and workspace_id
|
||||||
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
# Always get workspace_id from end_user for fallback, even if config_id is provided
|
||||||
@@ -437,7 +435,8 @@ class MemoryAgentService:
|
|||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
|
||||||
if config_id is None and workspace_id is None:
|
if config_id is None and workspace_id is None:
|
||||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
raise ValueError(
|
||||||
|
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "No memory configuration found" in str(e):
|
if "No memory configuration found" in str(e):
|
||||||
raise # Re-raise our specific error
|
raise # Re-raise our specific error
|
||||||
@@ -454,7 +453,6 @@ class MemoryAgentService:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
# Use a separate database session to avoid transaction failures
|
# Use a separate database session to avoid transaction failures
|
||||||
@@ -576,7 +574,8 @@ class MemoryAgentService:
|
|||||||
raw_results = intermediate.get('raw_results', {})
|
raw_results = intermediate.get('raw_results', {})
|
||||||
try:
|
try:
|
||||||
reranked_results = raw_results.get('reranked_results', [])
|
reranked_results = raw_results.get('reranked_results', [])
|
||||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
statements = [statement['statement'] for statement in
|
||||||
|
reranked_results.get('statements', [])]
|
||||||
except Exception:
|
except Exception:
|
||||||
statements = []
|
statements = []
|
||||||
|
|
||||||
@@ -602,7 +601,8 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
logger.debug(
|
||||||
|
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||||
|
|
||||||
except Exception as save_error:
|
except Exception as save_error:
|
||||||
# 保存失败不应该影响主流程,只记录错误
|
# 保存失败不应该影响主流程,只记录错误
|
||||||
@@ -610,7 +610,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# Log successful operation
|
# Log successful operation
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
logger.info(
|
||||||
|
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
@@ -641,7 +642,6 @@ class MemoryAgentService:
|
|||||||
)
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Get standardized message list from user input.
|
Get standardized message list from user input.
|
||||||
@@ -665,7 +665,8 @@ class MemoryAgentService:
|
|||||||
for idx, msg in enumerate(user_input.messages):
|
for idx, msg in enumerate(user_input.messages):
|
||||||
if not isinstance(msg, dict):
|
if not isinstance(msg, dict):
|
||||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||||
|
|
||||||
if 'role' not in msg:
|
if 'role' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||||
@@ -673,7 +674,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
if 'content' not in msg:
|
if 'content' not in msg:
|
||||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
raise ValueError(
|
||||||
|
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||||
|
|
||||||
if msg['role'] not in ['user', 'assistant']:
|
if msg['role'] not in ['user', 'assistant']:
|
||||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||||
@@ -687,11 +689,11 @@ class MemoryAgentService:
|
|||||||
return user_input.messages
|
return user_input.messages
|
||||||
|
|
||||||
async def classify_message_type(
|
async def classify_message_type(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: UUID,
|
config_id: UUID,
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: Optional[UUID] = None
|
workspace_id: Optional[UUID] = None
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -719,14 +721,15 @@ class MemoryAgentService:
|
|||||||
status = await status_typle(message, memory_config.llm_model_id)
|
status = await status_typle(message, memory_config.llm_model_id)
|
||||||
logger.debug(f"Message type: {status}")
|
logger.debug(f"Message type: {status}")
|
||||||
return status
|
return status
|
||||||
|
|
||||||
async def generate_summary_from_retrieve(
|
async def generate_summary_from_retrieve(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
retrieve_info: str,
|
retrieve_info: str,
|
||||||
history: List[Dict],
|
history: List[Dict],
|
||||||
query: str,
|
query: str,
|
||||||
config_id: str,
|
config_id: str,
|
||||||
db: Session
|
db: Session
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
基于检索信息、历史对话和查询生成最终答案
|
基于检索信息、历史对话和查询生成最终答案
|
||||||
@@ -805,13 +808,12 @@ class MemoryAgentService:
|
|||||||
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
|
||||||
return "信息不足,无法回答。"
|
return "信息不足,无法回答。"
|
||||||
|
|
||||||
|
|
||||||
async def get_knowledge_type_stats(
|
async def get_knowledge_type_stats(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
db: Session,
|
||||||
only_active: bool = True,
|
end_user_id: Optional[str] = None,
|
||||||
current_workspace_id: Optional[uuid.UUID] = None,
|
only_active: bool = True,
|
||||||
db: Session = None
|
current_workspace_id: Optional[uuid.UUID] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
统计知识库类型分布,包含:
|
统计知识库类型分布,包含:
|
||||||
@@ -837,11 +839,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 1. 统计 PostgreSQL 中的知识库类型
|
# 1. 统计 PostgreSQL 中的知识库类型
|
||||||
try:
|
try:
|
||||||
if db is None:
|
|
||||||
from app.db import get_db
|
|
||||||
db_gen = get_db()
|
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
# 初始化所有标准类型为 0
|
# 初始化所有标准类型为 0
|
||||||
for kb_type in KnowledgeType:
|
for kb_type in KnowledgeType:
|
||||||
result[kb_type.value] = 0
|
result[kb_type.value] = 0
|
||||||
@@ -881,21 +878,19 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 3. 计算知识库类型总和(不包括 memory)
|
# 3. 计算知识库类型总和(不包括 memory)
|
||||||
result["total"] = (
|
result["total"] = (
|
||||||
result.get("General", 0) +
|
result.get("General", 0) +
|
||||||
result.get("Web", 0) +
|
result.get("Web", 0) +
|
||||||
result.get("Third-party", 0) +
|
result.get("Third-party", 0) +
|
||||||
result.get("Folder", 0)
|
result.get("Folder", 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_interest_distribution_by_user(
|
async def get_interest_distribution_by_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定用户的兴趣分布标签。
|
获取指定用户的兴趣分布标签。
|
||||||
@@ -921,13 +916,12 @@ class MemoryAgentService:
|
|||||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_profile(
|
async def get_user_profile(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user_id: Optional[str] = None,
|
current_user_id: Optional[str] = None,
|
||||||
llm_id: Optional[str] = None,
|
llm_id: Optional[str] = None,
|
||||||
db: Session = None
|
db: Session = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取用户详情,包含:
|
获取用户详情,包含:
|
||||||
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
# 定义标签提取的结构
|
# 定义标签提取的结构
|
||||||
class UserTags(BaseModel):
|
class UserTags(BaseModel):
|
||||||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
tags: list[str] = Field(...,
|
||||||
|
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
ValueError: 当终端用户不存在或应用未发布时
|
ValueError: 当终端用户不存在或应用未发布时
|
||||||
"""
|
"""
|
||||||
import json as json_module
|
import json as json_module
|
||||||
import uuid
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
"workspace_id": str(app.workspace_id)
|
"workspace_id": str(app.workspace_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
logger.info(
|
||||||
|
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,45 +1,42 @@
|
|||||||
# 修改 memory_konwledges_server.py 文件
|
# 修改 memory_konwledges_server.py 文件
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from fastapi import HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db_context
|
||||||
from app.schemas import file_schema, document_schema
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
|
||||||
from app.models.document_model import Document
|
from app.models.document_model import Document
|
||||||
import uuid
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from fastapi import HTTPException, status
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
from app.schemas.file_schema import CustomTextFileCreate
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
from app.celery_app import celery_app
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.schemas.file_schema import CustomTextFileCreate
|
|
||||||
from app.db import get_db
|
|
||||||
# 创建一个简单的用户类用于测试
|
# 创建一个简单的用户类用于测试
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
|
||||||
class ChunkCreate(BaseModel):
|
class ChunkCreate(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
class SimpleUser:
|
class SimpleUser:
|
||||||
def __init__(self, user_id: str):
|
def __init__(self, user_id: str):
|
||||||
# 确保ID是UUID类型
|
# 确保ID是UUID类型
|
||||||
self.id = user_id
|
self.id = user_id
|
||||||
self.username = user_id
|
self.username = user_id
|
||||||
|
|
||||||
'''解析'''
|
|
||||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||||
"""
|
"""
|
||||||
解析指定文档
|
解析指定文档
|
||||||
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
|
|||||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
'''获取块ID'''
|
|
||||||
async def get_document_chunks(
|
async def get_document_chunks(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
@@ -198,7 +195,7 @@ async def get_document_chunks(
|
|||||||
|
|
||||||
return success(data=result, msg="文档块列表查询成功")
|
return success(data=result, msg="文档块列表查询成功")
|
||||||
|
|
||||||
'''查找文档ID'''
|
|
||||||
def find_document_id_by_kb_and_filename(
|
def find_document_id_by_kb_and_filename(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
'''获取知识库ID'''
|
|
||||||
def find_documents_by_kb_id(
|
def find_documents_by_kb_id(
|
||||||
db: Session,
|
db: Session,
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
''''上传文件'''
|
|
||||||
async def memory_konwledges_up(
|
async def memory_konwledges_up(
|
||||||
kb_id: str,
|
kb_id: str,
|
||||||
parent_id: str,
|
parent_id: str,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session,
|
||||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
current_user: SimpleUser,
|
||||||
):
|
):
|
||||||
# 如果没有提供current_user,则创建一个默认的
|
|
||||||
if current_user is None:
|
|
||||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
|
||||||
|
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
print(f"file size: {file_size} byte")
|
print(f"file size: {file_size} byte")
|
||||||
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
|
|||||||
|
|
||||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||||
|
|
||||||
'''添加新块'''
|
|
||||||
|
|
||||||
|
|
||||||
async def create_document_chunk(
|
async def create_document_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -450,6 +441,7 @@ async def create_document_chunk(
|
|||||||
|
|
||||||
return success(data=chunk, msg="文档块创建成功")
|
return success(data=chunk, msg="文档块创建成功")
|
||||||
|
|
||||||
|
|
||||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||||
"""
|
"""
|
||||||
将消息写入 RAG 知识库
|
将消息写入 RAG 知识库
|
||||||
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
db_gen = get_db()
|
with get_db_context() as db:
|
||||||
db = next(db_gen)
|
|
||||||
|
|
||||||
try:
|
|
||||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||||
current_user = SimpleUser(user_rag_memory_id)
|
current_user = SimpleUser(user_rag_memory_id)
|
||||||
# 检查文档是否已存在
|
# 检查文档是否已存在
|
||||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||||
print('======',document)
|
print('======', document)
|
||||||
api_logger.info(f"查找文档结果: document_id={document}")
|
api_logger.info(f"查找文档结果: document_id={document}")
|
||||||
if document is not None:
|
if document is not None:
|
||||||
# 文档已存在,直接添加新块
|
# 文档已存在,直接添加新块
|
||||||
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
|
|||||||
else:
|
else:
|
||||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||||
return result
|
return result
|
||||||
finally:
|
|
||||||
# 确保数据库会话被关闭
|
|
||||||
db.close()
|
|
||||||
@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
|||||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
from app.services.memory_base_service import MemoryBaseService
|
||||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
from app.services.memory_short_service import ShortService
|
from app.services.memory_short_service import ShortService
|
||||||
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
|
|
||||||
from app.core.language_utils import validate_language
|
from app.core.language_utils import validate_language
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||||
from app.db import get_db
|
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
# 验证语言参数
|
# 验证语言参数
|
||||||
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
# 获取数据库会话并查询用户信息
|
# 获取数据库会话并查询用户信息
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||||
if end_user and end_user.other_name:
|
if end_user and end_user.other_name:
|
||||||
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
|||||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user