Merge pull request #478 from SuanmoSuanyangTechnology/fix/db-connect-leak

fix(db): fix database connection leak
This commit is contained in:
Mark
2026-03-06 10:40:35 +08:00
committed by GitHub
12 changed files with 505 additions and 566 deletions

View File

@@ -1,28 +1,29 @@
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv()
api_logger = get_api_logger()
@@ -55,7 +56,8 @@ async def get_health_status(
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
@@ -161,13 +163,15 @@ async def write_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
@@ -216,7 +220,8 @@ async def write_server_async(
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
@@ -292,7 +297,8 @@ async def read_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
@@ -337,7 +344,8 @@ async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
@@ -631,7 +636,8 @@ async def status_type(
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,

View File

@@ -1,10 +1,10 @@
import os
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -53,6 +52,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
@@ -171,6 +171,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state):
@@ -50,6 +45,8 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
@@ -150,7 +146,6 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
@@ -167,6 +162,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
async def process_question_nodes(idx, question):
try:
@@ -260,8 +256,6 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
import time
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os
import time
@@ -18,12 +16,11 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
@@ -32,8 +29,11 @@ class SummaryNodeService(LLMServiceMixin):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
summary_service = SummaryNodeService()
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
@@ -51,6 +51,8 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
@@ -70,12 +72,15 @@ async def rag_knowledge(state,question):
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数包含更好的错误处理和数据验证
"""
@@ -99,6 +104,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
try:
# 使用优化的LLM服务进行结构化输出
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
@@ -157,6 +163,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
@@ -169,6 +176,8 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
@@ -206,6 +215,7 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState:
start = time.time()
storage_type = state.get("storage_type", '')
@@ -224,7 +234,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
@@ -253,6 +264,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
log_time('检索', duration)
return {"summary": summary}
async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
@@ -328,6 +340,8 @@ async def Summary(state: ReadState)-> ReadState:
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')

View File

@@ -1,8 +1,9 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -10,12 +11,12 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
@@ -23,9 +24,11 @@ class VerificationNodeService(LLMServiceMixin):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
@@ -58,6 +61,8 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
}
}
return Verify_result
async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
@@ -71,7 +76,8 @@ async def Verify(state: ReadState):
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
@@ -100,7 +106,8 @@ async def Verify(state: ReadState):
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
)
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
@@ -62,7 +60,6 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
@@ -76,6 +73,7 @@ async def make_read_graph():
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
@@ -97,8 +95,10 @@ async def main():
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
@@ -141,7 +141,6 @@ async def main():
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
@@ -165,6 +164,8 @@ async def main():
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end = time.time()
print(100 * 'y')
@@ -174,4 +175,5 @@ async def main():
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -14,7 +14,7 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db
from app.db import get_db_context
from app.models import AppRelease
from app.services.draft_run_service import AgentRunService
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑
Args:
@@ -57,7 +57,7 @@ class AgentNode(BaseNode):
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
with get_db_context() as db:
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
@@ -65,9 +65,9 @@ class AgentNode(BaseNode):
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = AgentRunService(db)
return draft_service, release, message
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行
@@ -79,9 +79,11 @@ class AgentNode(BaseNode):
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent非流式
result = await draft_service.run(
@@ -118,13 +120,14 @@ class AgentNode(BaseNode):
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,

View File

@@ -22,6 +22,7 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
@@ -103,9 +104,7 @@ def create_long_term_memory_tool(
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
from app.db import get_db
db = next(get_db())
try:
with get_db_context() as db:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
end_user_id=end_user_id,
@@ -127,9 +126,6 @@ def create_long_term_memory_tool(
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
"""
import json
import os
import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
"""
Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -375,19 +374,18 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
self,
end_user_id: str,
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -576,7 +574,8 @@ class MemoryAgentService:
raw_results = intermediate.get('raw_results', {})
try:
reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []
@@ -602,7 +601,8 @@ class MemoryAgentService:
)
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
@@ -610,7 +610,8 @@ class MemoryAgentService:
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -665,7 +665,8 @@ class MemoryAgentService:
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
@@ -673,7 +674,8 @@ class MemoryAgentService:
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
@@ -719,6 +721,7 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
async def generate_summary_from_retrieve(
self,
end_user_id: str,
@@ -805,13 +808,12 @@ class MemoryAgentService:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
db: Session,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None,
db: Session = None
current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型
try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0
for kb_type in KnowledgeType:
result[kb_type.value] = 0
@@ -889,8 +886,6 @@ class MemoryAgentService:
return result
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
@@ -921,7 +916,6 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构
class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [
{
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
import json as json_module
import uuid
from sqlalchemy import select
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件
import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.db import get_db_context
from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试
api_logger = get_api_logger()
class ChunkCreate(BaseModel):
content: str
class SimpleUser:
def __init__(self, user_id: str):
# 确保ID是UUID类型
self.id = user_id
self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
"""
解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise
'''获取块ID'''
async def get_document_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename(
db: Session,
kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e:
return None
'''获取知识库ID'''
def find_documents_by_kb_id(
db: Session,
kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e:
return []
''''上传文件'''
async def memory_konwledges_up(
kb_id: str,
parent_id: str,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: SimpleUser = None, # 修改为SimpleUser
db: Session,
current_user: SimpleUser,
):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk(
kb_id: uuid.UUID,
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id):
"""
将消息写入 RAG 知识库
@@ -483,10 +475,7 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}"
)
db_gen = get_db()
db = next(db_gen)
try:
with get_db_context() as db:
create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id:
try:
# 获取数据库会话并查询用户信息
db = next(get_db())
try:
with get_db_context() as db:
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")