fix(celery, rag): unify rag_write return type and remove deprecated downstream calls

- Unify the return type of `rag_write` in Celery tasks for consistency.
- Remove two deprecated downstream API calls to avoid obsolete dependencies.
This commit is contained in:
Eternity
2026-03-20 18:28:21 +08:00
parent c17a2dad2d
commit dce7206c44
4 changed files with 169 additions and 190 deletions

View File

@@ -267,8 +267,16 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(
self,
end_user_id: str,
messages: list[dict],
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str,
language: str = "zh"
) -> str:
"""
Process write operation with config_id
@@ -297,8 +305,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
f"Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -338,8 +346,8 @@ class MemoryAgentService:
if storage_type == "rag":
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
return result
await write_rag(end_user_id, message_text, user_rag_memory_id)
return "success"
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}

View File

@@ -341,7 +341,7 @@ async def memory_konwledges_up(
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
return db_document
async def create_document_chunk(
@@ -350,7 +350,7 @@ async def create_document_chunk(
create_data: ChunkCreate,
db: Session,
current_user: User
):
) -> DocumentChunk:
"""
创建文档块
@@ -439,10 +439,10 @@ async def create_document_chunk(
db_document.chunk_num += 1
db.commit()
return success(data=chunk, msg="文档块创建成功")
return chunk
async def write_rag(end_user_id, message, user_rag_memory_id):
async def write_rag(end_user_id, message, user_rag_memory_id) -> DocumentChunk:
"""
将消息写入 RAG 知识库
@@ -482,11 +482,11 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======', document)
api_logger.info(f"查找文档结果: document_id={document}")
create_chunks = ChunkCreate(content=message)
if document is not None:
# 文档已存在,直接添加新块
api_logger.info(f"文档已存在,添加新块: document_id={document}")
create_chunks = ChunkCreate(content=message)
result = await create_document_chunk(
kb_id=kb_uuid,
document_id=uuid.UUID(document),
@@ -498,13 +498,20 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else:
# 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
result = await memory_konwledges_up(
document = await memory_konwledges_up(
kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id,
create_data=create_data,
db=db,
current_user=current_user
)
result = await create_document_chunk(
kb_id=kb_uuid,
document_id=document.id,
create_data=create_chunks,
db=db,
current_user=current_user
)
# 重新查询刚创建的文档ID
new_document_id = find_document_id_by_kb_and_filename(
db=db,