- Unify the return type of `rag_write` in Celery tasks for consistency. - Remove two deprecated downstream API calls to avoid obsolete dependencies.
527 lines
17 KiB
Python
527 lines
17 KiB
Python
# 修改 memory_konwledges_server.py 文件
|
||
|
||
import os
|
||
import uuid
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
from fastapi import HTTPException, status
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.celery_app import celery_app
|
||
from app.core.config import settings
|
||
from app.core.logging_config import get_api_logger
|
||
from app.core.rag.models.chunk import DocumentChunk
|
||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||
from app.core.response_utils import success
|
||
from app.db import get_db_context
|
||
from app.models.document_model import Document
|
||
from app.models.user_model import User
|
||
from app.schemas import file_schema, document_schema
|
||
from app.schemas.file_schema import CustomTextFileCreate
|
||
from app.services import document_service, file_service, knowledge_service
|
||
|
||
# 创建一个简单的用户类用于测试
|
||
api_logger = get_api_logger()
|
||
|
||
|
||
class ChunkCreate(BaseModel):
|
||
content: str
|
||
|
||
|
||
class SimpleUser:
|
||
def __init__(self, user_id: str):
|
||
# 确保ID是UUID类型
|
||
self.id = user_id
|
||
self.username = user_id
|
||
|
||
|
||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||
"""
|
||
解析指定文档
|
||
|
||
Args:
|
||
document_id: 文档ID
|
||
db: 数据库会话
|
||
current_user: 当前用户
|
||
|
||
Returns:
|
||
dict: 包含任务ID的结果字典
|
||
|
||
Raises:
|
||
HTTPException: 当文档、文件或知识库不存在时抛出异常
|
||
"""
|
||
|
||
try:
|
||
# 1. 检查文档是否存在
|
||
api_logger.debug(f"检查文档是否存在: {document_id}")
|
||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||
|
||
if not db_document:
|
||
api_logger.warning(f"文档不存在或无访问权限: document_id={document_id}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="文档不存在或无访问权限"
|
||
)
|
||
|
||
# 2. 检查文件是否存在
|
||
api_logger.debug(f"检查文件是否存在: {db_document.file_id}")
|
||
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
|
||
|
||
if not db_file:
|
||
api_logger.warning(f"文件不存在或无访问权限: file_id={db_document.file_id}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="文件不存在或无访问权限"
|
||
)
|
||
|
||
# 3. 构建文件路径:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||
file_path = os.path.join(
|
||
settings.FILE_PATH,
|
||
str(db_file.kb_id),
|
||
str(db_file.parent_id),
|
||
f"{db_file.id}{db_file.file_ext}"
|
||
)
|
||
|
||
# 4. 检查文件是否存在于磁盘上
|
||
if not os.path.exists(file_path):
|
||
api_logger.warning(f"文件未找到(可能已被删除): file_path={file_path}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="文件未找到(可能已被删除)"
|
||
)
|
||
|
||
# 5. 获取知识库信息
|
||
api_logger.info(f"获取知识库详情: knowledge_id={db_document.kb_id}")
|
||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id,
|
||
current_user=current_user)
|
||
if not db_knowledge:
|
||
api_logger.warning(f"知识库不存在或访问被拒绝: knowledge_id={db_document.kb_id}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="知识库不存在或访问被拒绝"
|
||
)
|
||
|
||
# 6. 发送解析任务到Celery后台队列
|
||
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
||
|
||
result = {
|
||
"task_id": task.id
|
||
}
|
||
|
||
api_logger.info(f"文档解析任务已接受: document_id={document_id}, task_id={task.id}")
|
||
return result
|
||
|
||
except Exception as e:
|
||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||
raise
|
||
|
||
|
||
async def get_document_chunks(
|
||
kb_id: uuid.UUID,
|
||
document_id: uuid.UUID,
|
||
page: int = 1,
|
||
pagesize: int = 20,
|
||
keywords: Optional[str] = None,
|
||
db: Session = None,
|
||
current_user: User = None
|
||
):
|
||
"""
|
||
分页查询文档块列表
|
||
|
||
Args:
|
||
kb_id: 知识库ID
|
||
document_id: 文档ID
|
||
page: 页码,默认为1
|
||
pagesize: 每页大小,默认为20
|
||
keywords: 用于匹配块内容的关键字
|
||
db: 数据库会话
|
||
current_user: 当前用户
|
||
|
||
Returns:
|
||
dict: 包含分页数据的响应结果
|
||
|
||
Raises:
|
||
HTTPException: 当知识库不存在或查询失败时抛出异常
|
||
"""
|
||
api_logger.info(
|
||
f"分页查询文档块列表: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||
|
||
# 参数验证
|
||
if page < 1 or pagesize < 1:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="分页参数必须大于0"
|
||
)
|
||
|
||
# 获取知识库信息
|
||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||
if not db_knowledge:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="知识库不存在或访问被拒绝"
|
||
)
|
||
|
||
# 执行分页查询
|
||
try:
|
||
api_logger.debug("开始执行文档块查询")
|
||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||
total, items = vector_service.search_by_segment(
|
||
document_id=str(document_id),
|
||
query=keywords,
|
||
pagesize=pagesize,
|
||
page=page,
|
||
asc=True
|
||
)
|
||
api_logger.info(f"文档块查询成功: total={total}, returned={len(items)} records")
|
||
except Exception as e:
|
||
api_logger.error(f"文档块查询失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"查询失败: {str(e)}"
|
||
)
|
||
|
||
# 构造响应结果
|
||
result = {
|
||
"items": items,
|
||
"page": {
|
||
"page": page,
|
||
"pagesize": pagesize,
|
||
"total": total,
|
||
"has_next": True if page * pagesize < total else False
|
||
}
|
||
}
|
||
|
||
return success(data=result, msg="文档块列表查询成功")
|
||
|
||
|
||
def find_document_id_by_kb_and_filename(
|
||
db: Session,
|
||
kb_id: str,
|
||
file_name: str
|
||
) -> str | None:
|
||
"""
|
||
通过 kb_id 和 file_name 在 documents 表中查找对应的 ID
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
kb_id: 知识库ID
|
||
file_name: 文件名
|
||
|
||
Returns:
|
||
str | None: 找到的 document ID,未找到返回 None
|
||
"""
|
||
try:
|
||
# 查询 documents 表
|
||
document = db.query(Document).filter(
|
||
Document.kb_id == kb_id,
|
||
Document.file_name == file_name
|
||
).first()
|
||
|
||
if document:
|
||
print(f"找到文档: ID={document.id}, kb_id={kb_id}, file_name={file_name}")
|
||
return str(document.id)
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
return None
|
||
|
||
|
||
def find_documents_by_kb_id(
|
||
db: Session,
|
||
kb_id: str,
|
||
limit: int = 10
|
||
) -> list[dict]:
|
||
"""
|
||
通过 kb_id 查找所有相关文档
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
kb_id: 知识库ID
|
||
limit: 返回结果数量限制
|
||
|
||
Returns:
|
||
list[dict]: 文档列表,包含 id, file_name, created_at 等信息
|
||
"""
|
||
try:
|
||
documents = db.query(Document).filter(
|
||
Document.kb_id == kb_id
|
||
).limit(limit).all()
|
||
|
||
result = []
|
||
for doc in documents:
|
||
result.append({
|
||
"id": str(doc.id),
|
||
"file_name": doc.file_name,
|
||
"file_ext": doc.file_ext,
|
||
"file_size": doc.file_size,
|
||
"created_at": doc.created_at.isoformat() if doc.created_at else None,
|
||
"status": getattr(doc, 'status', None)
|
||
})
|
||
return result
|
||
|
||
except Exception as e:
|
||
return []
|
||
|
||
|
||
async def memory_konwledges_up(
|
||
kb_id: str,
|
||
parent_id: str,
|
||
create_data: file_schema.CustomTextFileCreate,
|
||
db: Session,
|
||
current_user: SimpleUser,
|
||
):
|
||
content_bytes = create_data.content.encode('utf-8')
|
||
file_size = len(content_bytes)
|
||
print(f"file size: {file_size} byte")
|
||
|
||
if file_size == 0:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="The content is empty."
|
||
)
|
||
|
||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||
if file_size > settings.MAX_FILE_SIZE:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||
)
|
||
|
||
upload_file = file_schema.FileCreate(
|
||
kb_id=kb_id,
|
||
created_by=current_user.id, # 现在是UUID类型
|
||
parent_id=parent_id,
|
||
file_name=f"{create_data.title}.txt",
|
||
file_ext=".txt",
|
||
file_size=file_size,
|
||
)
|
||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||
|
||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||
# 使用 settings.FILE_PATH 确保与 parse_document_by_id 一致
|
||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||
|
||
# 确保目录存在
|
||
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||
|
||
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
||
|
||
# Save file
|
||
with open(save_path, "wb") as f:
|
||
f.write(content_bytes)
|
||
|
||
# Verify whether the file has been saved successfully
|
||
if not os.path.exists(save_path):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="File save failed"
|
||
)
|
||
|
||
# Create a document
|
||
create_document_data = document_schema.DocumentCreate(
|
||
kb_id=kb_id,
|
||
created_by=current_user.id,
|
||
file_id=db_file.id,
|
||
file_name=db_file.file_name,
|
||
file_ext=db_file.file_ext,
|
||
file_size=db_file.file_size,
|
||
file_meta={},
|
||
parser_id="naive",
|
||
parser_config={
|
||
"layout_recognize": "DeepDOC",
|
||
"chunk_token_num": 128,
|
||
"delimiter": "\n",
|
||
"auto_keywords": 0,
|
||
"auto_questions": 0,
|
||
"html4excel": "false"
|
||
}
|
||
)
|
||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||
|
||
return db_document
|
||
|
||
|
||
async def create_document_chunk(
|
||
kb_id: uuid.UUID,
|
||
document_id: uuid.UUID,
|
||
create_data: ChunkCreate,
|
||
db: Session,
|
||
current_user: User
|
||
) -> DocumentChunk:
|
||
"""
|
||
创建文档块
|
||
|
||
Args:
|
||
kb_id: 知识库ID
|
||
document_id: 文档ID
|
||
create_data: 创建数据
|
||
db: 数据库会话
|
||
current_user: 当前用户
|
||
|
||
Returns:
|
||
dict: 包含创建的文档块信息的成功响应
|
||
|
||
Raises:
|
||
HTTPException: 当知识库或文档不存在时抛出相应异常
|
||
"""
|
||
api_logger.info(
|
||
f"创建文档块请求: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
|
||
|
||
# 1. 获取知识库信息
|
||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||
if not db_knowledge:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="知识库不存在或访问被拒绝"
|
||
)
|
||
|
||
# 2. 获取文档信息
|
||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||
if not db_document:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail="文档不存在或您无访问权限"
|
||
)
|
||
|
||
# 3. 初始化向量服务
|
||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||
|
||
# 4. 获取排序ID(处理索引不存在的情况)
|
||
sort_id = 0
|
||
try:
|
||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
||
if items:
|
||
sort_id = items[0].metadata["sort_id"]
|
||
except Exception as e:
|
||
# 如果索引不存在,从 0 开始
|
||
error_msg = str(e)
|
||
if "index_not_found_exception" in error_msg or "no such index" in error_msg:
|
||
api_logger.warning(f"索引不存在,将从 sort_id=0 开始: {error_msg}")
|
||
sort_id = 0
|
||
else:
|
||
# 其他错误则抛出
|
||
api_logger.error(f"查询文档块失败: {error_msg}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"查询文档块失败: {error_msg}"
|
||
)
|
||
|
||
sort_id = sort_id + 1
|
||
|
||
# 5. 创建文档块
|
||
doc_id = uuid.uuid4().hex
|
||
metadata = {
|
||
"doc_id": doc_id,
|
||
"file_id": str(db_document.file_id),
|
||
"file_name": db_document.file_name,
|
||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||
"document_id": str(document_id),
|
||
"knowledge_id": str(kb_id),
|
||
"sort_id": sort_id,
|
||
"status": 1,
|
||
}
|
||
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
|
||
|
||
# 6. 存储向量化的文档块(这会自动创建索引如果不存在)
|
||
try:
|
||
vector_service.add_chunks([chunk])
|
||
except Exception as e:
|
||
api_logger.error(f"添加文档块到向量库失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"添加文档块到向量库失败: {str(e)}"
|
||
)
|
||
|
||
# 7. 更新 chunk_num
|
||
db_document.chunk_num += 1
|
||
db.commit()
|
||
|
||
return chunk
|
||
|
||
|
||
async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
|
||
"""
|
||
将消息写入 RAG 知识库
|
||
|
||
Args:
|
||
end_user_id: 组ID,用作文件标题
|
||
message: 消息内容
|
||
user_rag_memory_id: 知识库ID(必须是有效的UUID)
|
||
|
||
Returns:
|
||
写入结果
|
||
|
||
Raises:
|
||
HTTPException: 当参数无效或操作失败时
|
||
"""
|
||
# 验证 user_rag_memory_id 是否为有效的 UUID
|
||
if not user_rag_memory_id:
|
||
api_logger.error("user_rag_memory_id 为空,无法执行 RAG 写入操作")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="知识库ID不能为空"
|
||
)
|
||
|
||
try:
|
||
# 尝试将字符串转换为 UUID 以验证格式
|
||
kb_uuid = uuid.UUID(user_rag_memory_id)
|
||
except (ValueError, AttributeError) as e:
|
||
api_logger.error(f"user_rag_memory_id 不是有效的UUID: {user_rag_memory_id}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||
)
|
||
|
||
with get_db_context() as db:
|
||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||
current_user = SimpleUser(user_rag_memory_id)
|
||
# 检查文档是否已存在
|
||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||
print('======', document)
|
||
api_logger.info(f"查找文档结果: document_id={document}")
|
||
create_chunks = ChunkCreate(content=message)
|
||
if document is not None:
|
||
# 文档已存在,直接添加新块
|
||
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
||
|
||
result = await create_document_chunk(
|
||
kb_id=kb_uuid,
|
||
document_id=uuid.UUID(document),
|
||
create_data=create_chunks,
|
||
db=db,
|
||
current_user=current_user
|
||
)
|
||
return result
|
||
else:
|
||
# 文档不存在,创建新文档
|
||
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||
document = await memory_konwledges_up(
|
||
kb_id=user_rag_memory_id,
|
||
parent_id=user_rag_memory_id,
|
||
create_data=create_data,
|
||
db=db,
|
||
current_user=current_user
|
||
)
|
||
result = await create_document_chunk(
|
||
kb_id=kb_uuid,
|
||
document_id=document.id,
|
||
create_data=create_chunks,
|
||
db=db,
|
||
current_user=current_user
|
||
)
|
||
# 重新查询刚创建的文档ID
|
||
new_document_id = find_document_id_by_kb_and_filename(
|
||
db=db,
|
||
kb_id=user_rag_memory_id,
|
||
file_name=f"{end_user_id}.txt"
|
||
)
|
||
|
||
if new_document_id:
|
||
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
|
||
else:
|
||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||
return result
|