fix(db): fix database connection leak

This commit is contained in:
Eternity
2026-03-06 10:12:21 +08:00
parent f90e102854
commit aaa0410781
12 changed files with 505 additions and 566 deletions

View File

@@ -1,28 +1,29 @@
from typing import List, Optional from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app from app.celery_app import celery_app
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success from app.core.response_utils import fail, success
from app.db import get_db from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey from app.models import ModelApiKey
from app.models.user_model import User from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService from app.repositories import knowledge_repository
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv() load_dotenv()
api_logger = get_api_logger() api_logger = get_api_logger()
@@ -37,7 +38,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse) @router.get("/health/status", response_model=ApiResponse)
async def get_health_status( async def get_health_status(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get latest health status written by Celery periodic task Get latest health status written by Celery periodic task
@@ -55,8 +56,9 @@ async def get_health_status(
@router.get("/download_log") @router.get("/download_log")
async def download_log( async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), log_type: str = Query("file", regex="^(file|transmission)$",
current_user: User = Depends(get_current_user) description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
): ):
""" """
Download or stream agent service log file Download or stream agent service log file
@@ -119,10 +121,10 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse) @router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server( async def write_server(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
@@ -161,13 +163,15 @@ async def write_server(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
else: else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
else: else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
@@ -195,10 +199,10 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse) @router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def write_server_async( async def write_server_async(
user_input: Write_UserInput, user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
@@ -216,7 +220,8 @@ async def write_server_async(
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
@@ -254,9 +259,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def read_server( async def read_server(
user_input: UserInput, user_input: UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Read service endpoint - processes read operations synchronously Read service endpoint - processes read operations synchronously
@@ -292,7 +297,8 @@ async def read_server(
if knowledge: if knowledge:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.end_user_id, user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
@@ -319,7 +326,7 @@ async def read_server(
db=db db=db
) )
if "信息不足,无法回答" in result['answer']: if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -335,9 +342,10 @@ async def read_server(
@router.post("/file", response_model=ApiResponse) @router.post("/file", response_model=ApiResponse)
async def file_update( async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"), files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"), model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
文件上传接口 - 支持图片识别 文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
Returns: Returns:
文件处理结果 文件处理结果
""" """
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}") api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0] apiConfig: ModelApiKey = config.api_keys[0]
@@ -430,8 +435,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse) @router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result( async def get_read_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async read task Get the status and result of an async read task
@@ -507,8 +512,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse) @router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result( async def get_write_task_result(
task_id: str, task_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Get the status and result of an async write task Get the status and result of an async write task
@@ -584,9 +589,9 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse) @router.post("/status_type", response_model=ApiResponse)
async def status_type( async def status_type(
user_input: Write_UserInput, user_input: Write_UserInput,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -629,9 +634,10 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse) @router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api( async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"), only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
): ):
""" """
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0 - 如果用户没有当前工作空间,对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try: try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数 # 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats( result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api( async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"), end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"), limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
获取指定用户的兴趣分布标签 获取指定用户的兴趣分布标签
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse) @router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api( async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"), end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -782,9 +783,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config( async def get_end_user_connected_config(
end_user_id: str, end_user_id: str,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取终端用户关联的记忆配置 获取终端用户关联的记忆配置

View File

@@ -1,10 +1,10 @@
import os
import json import json
import os
import time import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
structured = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
# 添加更详细的日志记录 # 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try: try:
# 使用优化的LLM服务 # 使用优化的LLM服务
response_content = await problem_service.call_llm_structured( with get_db_context() as db_session:
state=state, response_content = await problem_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=ProblemExtensionResponse, system_prompt=system_prompt,
fallback_value=[] response_model=ProblemExtensionResponse,
) fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 ===== # ===== 第三方库 =====
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import ( from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync, create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool, create_time_retrieval_tool,
extract_tool_message_content, extract_tool_message_content,
) )
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state): async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState: async def llm_infomation(state: ReadState) -> ReadState:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str): elif isinstance(item, str):
text_parts.append(item) text_parts.append(item)
return '\n'.join(text_parts).strip() return '\n'.join(text_parts).strip()
except Exception as e: except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve_nodes(state: ReadState) -> ReadState:
''' '''
模型信息 模型信息
''' '''
problem_extension=state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original = state.get('data', '')
problem_list=[] problem_list = []
for key,values in problem_extension.items(): for key, values in problem_extension.items():
for data in values: for data in values:
problem_list.append(data) problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题
async def process_question_nodes(idx, question): async def process_question_nodes(idx, question):
try: try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
send_verify = [] send_verify = []
for i, j in zip(keys, val, strict=False): for i, j in zip(keys, val, strict=False):
if j!=['']: if j != ['']:
send_verify.append({ send_verify.append({
"Query_small": i, "Query_small": i,
"Answer_Small": j "Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
} }
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases} return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id # 从state中获取end_user_id
import time import time
start=time.time() start = time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器 with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
return await llm_infomation(state) return await llm_infomation(state)
llm_config = await get_llm_info() llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0] api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
) )
time_retrieval_tool = create_time_retrieval_tool(end_user_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True } search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发 async with SEMAPHORE: # 限制并发
try: try:
if storage_type == "rag" and user_rag_memory_id: if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else: else:
cleaned_query = question cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke # 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4) # json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases} return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os import os
import time import time
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin): class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类""" """总结节点服务类"""
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def rag_config(state): async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = { kb_config = {
@@ -51,10 +51,12 @@ async def rag_config(state):
"reranker_top_k": 10 "reranker_top_k": 10
} }
return kb_config return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -62,20 +64,23 @@ async def rag_knowledge(state,question):
cleaned_query = question cleaned_query = question
raw_results = clean_content raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception : except Exception:
retrieval_knowledge=[] retrieval_knowledge = []
clean_content = '' clean_content = ''
raw_results = '' raw_results = ''
cleaned_query = question cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
""" """
增强的summary_llm函数包含更好的错误处理和数据验证 增强的summary_llm函数包含更好的错误处理和数据验证
""" """
@@ -99,13 +104,14 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
) )
try: try:
# 使用优化的LLM服务进行结构化输出 # 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured( with get_db_context() as db_session:
state=state, structured = await summary_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=response_model, system_prompt=system_prompt,
fallback_value=None response_model=response_model,
) fallback_value=None
)
# 验证结构化响应 # 验证结构化响应
if structured is None: if structured is None:
logger.warning("LLM返回None使用默认回答") logger.warning("LLM返回None使用默认回答")
@@ -157,7 +163,8 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
logger.error(f"Fallback也失败: {fallback_error}") logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答" return "信息不足,无法回答"
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
end_user_id = state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功") logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'') async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
data=state.get("data", '') storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = { input_summary = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
} }
retrieve={ retrieve = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id, "user_rag_memory_id": user_rag_memory_id,
"_intermediate": { "_intermediate": {
"type": "retrieval_summary", "type": "retrieval_summary",
"title":"快速检索", "title": "快速检索",
"summary": aimessages, "summary": aimessages,
"query": data, "query": data,
"storage_type": storage_type, "storage_type": storage_type,
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
} }
} }
return input_summary,retrieve return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState: async def Input_Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
storage_type=state.get("storage_type",'') storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id = state.get("user_rag_memory_id", '')
data=state.get("data", '') data = state.get("data", '')
end_user_id=state.get("end_user_id", '') end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history(state)
search_params = { search_params = {
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
} }
try: try:
if storage_type!="rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else: else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e: except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, [] retrieve_info, question, raw_results = "", data, []
try: try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0] summary = summary_result[0]
except Exception as e: except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True ) logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary= { summary = {
"status": "fail", "status": "fail",
"summary_result": "信息不足,无法回答", "summary_result": "信息不足,无法回答",
"storage_type": storage_type, "storage_type": storage_type,
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception: except Exception:
duration = 0.0 duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary":summary} return {"summary": summary}
async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '') async def Retrieve_Summary(state: ReadState) -> ReadState:
history = await summary_history( state) retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json import json
with open("检索.json","w",encoding='utf-8') as f: with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", []) retrieve = retrieve.get("Expansion_issue", [])
start=time.time() start = time.time()
retrieve_info_str=[] retrieve_info_str = []
for data in retrieve: for data in retrieve:
if data=='': if data == '':
retrieve_info_str='' retrieve_info_str = ''
else: else:
for key, value in data.items(): for key, value in data.items():
if key=='Answer_Small': if key == 'Answer_Small':
for i in value: for i in value:
retrieve_info_str.append(i) retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str, aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -290,29 +302,29 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary(state: ReadState)-> ReadState: async def Summary(state: ReadState) -> ReadState:
start=time.time() start = time.time()
query = state.get("data", '') query = state.get("data", '')
verify=state.get("verify", '') verify = state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '') verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str='' retrieve_info_str = ''
for data in verify_expansion_issue: for data in verify_expansion_issue:
for key, value in data.items(): for key, value in data.items():
if key=='answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str+=i+'\n' retrieve_info_str += i + '\n'
history=await summary_history(state) history = await summary_history(state)
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages=await summary_llm(state,history,data, aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2','summary',SummaryResponse,0) 'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值 # 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1] summary = summary_result[1]
return {"summary":summary} return {"summary": summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '') async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state) history = await summary_history(state)
query = state.get("data", '') query = state.get("data", '')
verify = state.get("verify", '') verify = state.get("verify", '')
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history, "history": history,
"retrieve_info": retrieve_info_str "retrieve_info": retrieve_info_str
} }
aimessages = await summary_llm(state, history, data, aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= { result = {
"status": "success", "status": "success",
"summary_result": aimessages, "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }
return {"summary":result} return {"summary": result}

View File

@@ -1,8 +1,9 @@
import asyncio
import os import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_, PROJECT_ROOT_,
ReadState, ReadState,
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin): class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类""" """验证节点服务类"""
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
super().__init__() super().__init__()
self.template_service = TemplateService(template_root) self.template_service = TemplateService(template_root)
# 创建全局服务实例 # 创建全局服务实例
verification_service = VerificationNodeService() verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式""" """处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
} }
} }
return Verify_result return Verify_result
async def Verify(state: ReadState): async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
@@ -100,23 +106,24 @@ async def Verify(state: ReadState):
try: try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待 # 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for( with get_db_context() as db_session:
verification_service.call_llm_structured( structured = await asyncio.wait_for(
state=state, verification_service.call_llm_structured(
db_session=db_session, state=state,
system_prompt=system_prompt, db_session=db_session,
response_model=VerificationResult, system_prompt=system_prompt,
fallback_value={ response_model=VerificationResult,
"query": content, fallback_value={
"history": history if isinstance(history, list) else [], "query": content,
"expansion_issue": [], "history": history if isinstance(history, list) else [],
"split_result": "failed", "expansion_issue": [],
"reason": "验证失败或超时" "split_result": "failed",
} "reason": "验证失败或超时"
), }
timeout=150.0 # 150秒超时 ),
) timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}") logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值") logger.error("Verify: LLM 调用超时150秒使用 fallback 值")

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END from langgraph.constants import START, END
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db from app.db import get_db
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
) )
@asynccontextmanager @asynccontextmanager
async def make_read_graph(): async def make_read_graph():
"""创建并返回 LangGraph 工作流""" """创建并返回 LangGraph 工作流"""
@@ -62,7 +60,6 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----''' '''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
@@ -76,6 +73,7 @@ async def make_read_graph():
finally: finally:
print("工作流创建完成") print("工作流创建完成")
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
@@ -92,13 +90,15 @@ async def main():
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
import time import time
start=time.time() start = time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} "end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
@@ -141,7 +141,6 @@ async def main():
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n) _intermediate_outputs.append(verify_n)
# Summary 节点 # Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None) summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}: if summary_n and summary_n != [] and summary_n != {}:
@@ -165,13 +164,16 @@ async def main():
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally:
db_session.close()
end=time.time() end = time.time()
print(100*'y') print(100 * 'y')
print(f"总耗时: {end-start}s") print(f"总耗时: {end - start}s")
print(100*'y') print(100 * 'y')
if __name__ == "__main__": if __name__ == "__main__":
import asyncio import asyncio
asyncio.run(main()) asyncio.run(main())

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db from app.db import get_db_context
from app.models import AppRelease from app.models import AppRelease
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]: def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑 """准备 Agent公共逻辑
Args: Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
if not agent_id: if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db()) with get_db_context() as db:
release = db.query(AppRelease).filter( release = db.query(AppRelease).filter(
AppRelease.id == agent_id AppRelease.id == agent_id
).first() ).first()
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = AgentRunService(db)
return draft_service, release, message
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行 """非流式执行
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
Returns: Returns:
状态更新字典 状态更新字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent非流式 # 执行 Agent非流式
result = await draft_service.run( result = await draft_service.run(
agent_config=release.config, agent_config=release.config,
model_config=None, model_config=None,
message=message, message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"), workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"), user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars() variables=variable_pool.get_all_conversation_vars()
) )
response = result.get("response", "") response = result.get("response", "")
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
Yields: Yields:
流式事件字典 流式事件字典
""" """
draft_service, release, message = self._prepare_agent(variable_pool) release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式 # 执行 Agent流式
async for chunk in draft_service.run_stream( async for chunk in draft_service.run_stream(
agent_config=release.config, agent_config=release.config,
model_config=None, model_config=None,
message=message, message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"), workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"), user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars() variables=variable_pool.get_all_conversation_vars()
): ):
# 提取内容 # 提取内容
content = chunk.get("content", "") content = chunk.get("content", "")
full_response += content full_response += content
# 流式返回每个 chunk # 流式返回每个 chunk
yield { yield {
"type": "chunk", "type": "chunk",
"node_id": self.node_id, "node_id": self.node_id,
"content": content, "content": content,
"full_content": full_response, "full_content": full_response,
"meta_data": chunk.get("meta_data", {}) "meta_data": chunk.get("meta_data", {})
} }
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")

View File

@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput from app.schemas.app_schema import FileInput
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
""" """
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}") logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try: try:
from app.db import get_db with get_db_context() as db:
db = next(get_db())
try:
memory_content = asyncio.run( memory_content = asyncio.run(
MemoryAgentService().read_memory( MemoryAgentService().read_memory(
end_user_id=end_user_id, end_user_id=end_user_id,
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
logger.info(f"读取任务状态:{status}") logger.info(f"读取任务状态:{status}")
if memory_content: if memory_content:
memory_content = memory_content['answer'] memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}') logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
""" """
import json import json
import os import os
import re
import time import time
import uuid import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results, reorder_output_results,
) )
from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)}) duration=duration, details={"message_length": len(message)})
return context return context
else: else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool: def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event""" """Extract tool call information from event"""
last_message = event["messages"][-1] last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation # Log failed operation
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content'])) langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant': elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content'])) langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-') print(100 * '-')
print(langchain_messages) print(langchain_messages)
print(100*'-') print(100 * '-')
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result') contents = massages.get('write_result')
# Convert messages back to string for logging # Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e: except Exception as e:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
async def read_memory( async def read_memory(
self, self,
end_user_id: str, end_user_id: str,
message: str, message: str,
history: List[Dict], history: List[Dict],
search_switch: str, search_switch: str,
config_id: Optional[uuid.UUID]|int, config_id: Optional[uuid.UUID] | int,
db: Session, db: Session,
storage_type: str, storage_type: str,
user_rag_memory_id: str) -> Dict: user_rag_memory_id: str) -> Dict:
""" """
Process read operation with config_id Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time import time
start_time = time.time() start_time = time.time()
ori_message= message ori_message = message
# Resolve config_id and workspace_id # Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided # Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None: if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e: except Exception as e:
if "No memory configuration found" in str(e): if "No memory configuration found" in str(e):
raise # Re-raise our specific error raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -576,7 +574,8 @@ class MemoryAgentService:
raw_results = intermediate.get('raw_results', {}) raw_results = intermediate.get('raw_results', {})
try: try:
reranked_results = raw_results.get('reranked_results', []) reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])] statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception: except Exception:
statements = [] statements = []
@@ -602,7 +601,8 @@ class MemoryAgentService:
) )
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else: else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error: except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误 # 保存失败不应该影响主流程,只记录错误
@@ -610,7 +610,8 @@ class MemoryAgentService:
# Log successful operation # Log successful operation
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
) )
raise ValueError(error_msg) raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
""" """
Get standardized message list from user input. Get standardized message list from user input.
@@ -665,7 +665,8 @@ class MemoryAgentService:
for idx, msg in enumerate(user_input.messages): for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict): if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg: if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
@@ -673,7 +674,8 @@ class MemoryAgentService:
if 'content' not in msg: if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']: if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
@@ -687,11 +689,11 @@ class MemoryAgentService:
return user_input.messages return user_input.messages
async def classify_message_type( async def classify_message_type(
self, self,
message: str, message: str,
config_id: UUID, config_id: UUID,
db: Session, db: Session,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> Dict: ) -> Dict:
""" """
Determine the type of user message (read or write) Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id) status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}") logger.debug(f"Message type: {status}")
return status return status
async def generate_summary_from_retrieve( async def generate_summary_from_retrieve(
self, self,
end_user_id: str, end_user_id: str,
retrieve_info: str, retrieve_info: str,
history: List[Dict], history: List[Dict],
query: str, query: str,
config_id: str, config_id: str,
db: Session db: Session
) -> str: ) -> str:
""" """
基于检索信息、历史对话和查询生成最终答案 基于检索信息、历史对话和查询生成最终答案
@@ -805,13 +808,12 @@ class MemoryAgentService:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True) logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。" return "信息不足,无法回答。"
async def get_knowledge_type_stats( async def get_knowledge_type_stats(
self, self,
end_user_id: Optional[str] = None, db: Session,
only_active: bool = True, end_user_id: Optional[str] = None,
current_workspace_id: Optional[uuid.UUID] = None, only_active: bool = True,
db: Session = None current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
统计知识库类型分布,包含: 统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型 # 1. 统计 PostgreSQL 中的知识库类型
try: try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0 # 初始化所有标准类型为 0
for kb_type in KnowledgeType: for kb_type in KnowledgeType:
result[kb_type.value] = 0 result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory # 3. 计算知识库类型总和(不包括 memory
result["total"] = ( result["total"] = (
result.get("General", 0) + result.get("General", 0) +
result.get("Web", 0) + result.get("Web", 0) +
result.get("Third-party", 0) + result.get("Third-party", 0) +
result.get("Folder", 0) result.get("Folder", 0)
) )
return result return result
async def get_interest_distribution_by_user( async def get_interest_distribution_by_user(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 5, limit: int = 5,
language: str = "zh" language: str = "zh"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取指定用户的兴趣分布标签。 获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}") logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}") raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile( async def get_user_profile(
self, self,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None, current_user_id: Optional[str] = None,
llm_id: Optional[str] = None, llm_id: Optional[str] = None,
db: Session = None db: Session = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
获取用户详情,包含: 获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构 # 定义标签提取的结构
class UserTags(BaseModel): class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友") tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [ messages = [
{ {
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时 ValueError: 当终端用户不存在或应用未发布时
""" """
import json as json_module import json as json_module
import uuid
from sqlalchemy import select from sqlalchemy import select
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id) "workspace_id": str(app.workspace_id)
} }
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result return result

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件 # 修改 memory_konwledges_server.py 文件
import asyncio
import os import os
import re
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db_context
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.models.document_model import Document from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试 # 创建一个简单的用户类用于测试
api_logger = get_api_logger() api_logger = get_api_logger()
class ChunkCreate(BaseModel): class ChunkCreate(BaseModel):
content: str content: str
class SimpleUser: class SimpleUser:
def __init__(self, user_id: str): def __init__(self, user_id: str):
# 确保ID是UUID类型 # 确保ID是UUID类型
self.id = user_id self.id = user_id
self.username = user_id self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
""" """
解析指定文档 解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise raise
'''获取块ID'''
async def get_document_chunks( async def get_document_chunks(
kb_id: uuid.UUID, kb_id: uuid.UUID,
document_id: uuid.UUID, document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功") return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename( def find_document_id_by_kb_and_filename(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e: except Exception as e:
return None return None
'''获取知识库ID'''
def find_documents_by_kb_id( def find_documents_by_kb_id(
db: Session, db: Session,
kb_id: str, kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e: except Exception as e:
return [] return []
''''上传文件'''
async def memory_konwledges_up( async def memory_konwledges_up(
kb_id: str, kb_id: str,
parent_id: str, parent_id: str,
create_data: file_schema.CustomTextFileCreate, create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db), db: Session,
current_user: SimpleUser = None, # 修改为SimpleUser current_user: SimpleUser,
): ):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8') content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes) file_size = len(content_bytes)
print(f"file size: {file_size} byte") print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk( async def create_document_chunk(
kb_id: uuid.UUID, kb_id: uuid.UUID,
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功") return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id): async def write_rag(end_user_id, message, user_rag_memory_id):
""" """
将消息写入 RAG 知识库 将消息写入 RAG 知识库
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}" detail=f"知识库ID格式无效: {user_rag_memory_id}"
) )
db_gen = get_db() with get_db_context() as db:
db = next(db_gen)
try:
create_data = CustomTextFileCreate(title=end_user_id, content=message) create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id) current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在 # 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document) print('======', document)
api_logger.info(f"查找文档结果: document_id={document}") api_logger.info(f"查找文档结果: document_id={document}")
if document is not None: if document is not None:
# 文档已存在,直接添加新块 # 文档已存在,直接添加新块
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else: else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService from app.services.memory_base_service import MemoryBaseService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数 # 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id: if end_user_id:
try: try:
# 获取数据库会话并查询用户信息 # 获取数据库会话并查询用户信息
db = next(get_db()) with get_db_context() as db:
try:
repo = EndUserRepository(db) repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id)) end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name: if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else: else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e: except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")