Files
MemoryBear/api/app/tasks.py
2025-12-15 14:09:43 +08:00

467 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import asyncio
from typing import Any, Dict, List, Optional
import requests
from datetime import datetime, timezone
import time
import uuid
from math import ceil
import redis
import json
from app.db import get_db
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.core.rag.llm.cv_model import QWenCV
from app.core.rag.llm.chat_model import Base
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.models.chunk import DocumentChunk
from app.services.memory_agent_service import MemoryAgentService
from app.core.config import settings
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
from app.core.rag.prompts.generator import question_proposal
# Import a unified Celery instance
from app.celery_app import celery_app
@celery_app.task(name="tasks.process_item")
def process_item(item: dict):
    """Simulate a slow background job for a single item.

    Stands in for real long-running work (sending mail, building a report,
    calling an external API, ...). The task blocks for five seconds and then
    reports the item's name and price.

    Args:
        item: Mapping that provides at least the ``name`` and ``price`` keys.

    Returns:
        A human-readable summary string describing the processed item.
    """
    name = item["name"]
    print(f"Processing item: {name}")

    # Pretend to do five seconds of work.
    time.sleep(5)

    price = item["price"]
    summary = f"Item '{name}' processed successfully at a price of ${price}."
    print(summary)
    return summary
@celery_app.task(name="app.core.rag.tasks.parse_document")
def parse_document(file_path: str, document_id: uuid.UUID):
    """Parse a document, vectorize its chunks and store them in the vector index.

    The task records its progress on the ``Document`` row: ``progress`` goes
    from 0.0 to 0.8 while parsing and from 0.8 to 1.0 while the chunks are
    embedded and imported in batches; ``progress_msg`` accumulates a
    timestamped log.

    Args:
        file_path: Path of the file handed to the chunker.
        document_id: Primary key of the ``Document`` row to process.

    Returns:
        A status string: ``"... processed successfully."`` on success or
        ``"... failed."`` on failure. Errors are caught and reported through
        the return value and ``progress_msg`` rather than raised.
    """
    db = next(get_db())  # get_db is a generator; pull the session out manually
    db_document = None
    db_knowledge = None
    progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
    try:
        db_document = db.query(Document).filter(Document.id == document_id).first()
        if db_document is None:
            raise ValueError(f"document '{document_id}' not found")
        db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
        if db_knowledge is None:
            raise ValueError(f"knowledge base '{db_document.kb_id}' not found")

        # 1. Document parsing & segmentation
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
        start_time = time.time()
        db_document.progress = 0.0
        db_document.progress_msg = progress_msg
        db_document.process_begin_at = datetime.now(tz=timezone.utc)
        db_document.process_duration = 0.0
        db_document.run = 1
        db.commit()
        db.refresh(db_document)

        def progress_callback(prog=None, msg=None):
            # Append parser progress to the enclosing task's log accumulator.
            nonlocal progress_msg
            progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"

        # Configure the chat model and the vision model from the knowledge base
        chat_model = Base(
            key=db_knowledge.llm.api_keys[0].api_key,
            model_name=db_knowledge.llm.api_keys[0].model_name,
            base_url=db_knowledge.llm.api_keys[0].api_base
        )
        vision_model = QWenCV(
            key=db_knowledge.image2text.api_keys[0].api_key,
            model_name=db_knowledge.image2text.api_keys[0].model_name,
            lang="Chinese",
            base_url=db_knowledge.image2text.api_keys[0].api_base
        )
        from app.core.rag.app.naive import chunk
        res = chunk(filename=file_path,
                    from_page=0,
                    to_page=100000,
                    callback=progress_callback,
                    vision_model=vision_model,
                    parser_config=db_document.parser_config,
                    is_root=False)
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
        db_document.progress = 0.8
        db_document.progress_msg = progress_msg
        db.commit()
        db.refresh(db_document)

        # 2. Document vectorization and storage
        total_chunks = len(res)
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
        batch_size = 100
        total_batches = ceil(total_chunks / batch_size)
        # Share of the remaining 20% covered by each batch. A document that
        # produced no chunks has no batches, so avoid dividing by zero.
        progress_per_batch = 0.2 / total_batches if total_batches else 0.0
        vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
        # 2.1 Delete the vectors previously indexed for this document
        vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
        # 2.2 Vectorize and import the chunks batch by batch
        # Number of questions to generate per chunk; falsy disables the feature.
        auto_questions = db_document.parser_config.get("auto_questions", 0)
        for batch_start in range(0, total_chunks, batch_size):
            batch_end = min(batch_start + batch_size, total_chunks)  # prevent out-of-bounds
            batch = res[batch_start:batch_end]
            chunks = []
            for idx_in_batch, item in enumerate(batch):
                global_idx = batch_start + idx_in_batch  # position within the whole document
                metadata = {
                    "doc_id": uuid.uuid4().hex,
                    "file_id": str(db_document.file_id),
                    "file_name": db_document.file_name,
                    "file_created_at": int(db_document.created_at.timestamp() * 1000),
                    "document_id": str(db_document.id),
                    "knowledge_id": str(db_document.kb_id),
                    "sort_id": global_idx,
                    "status": 1,
                }
                if auto_questions:
                    topn = auto_questions
                    # Reuse cached questions for identical content when available.
                    cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn})
                    if not cached:
                        cached = question_proposal(chat_model, item["content_with_weight"], topn)
                        set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn})
                    chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata))
                else:
                    chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
            # Bulk import of the current batch
            vector_service.add_chunks(chunks)
            # Update progress
            db_document.progress += progress_per_batch
            progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
            db_document.progress_msg = progress_msg
            db_document.process_duration = time.time() - start_time
            # NOTE(review): run is reset to 0 after every batch although the
            # task is still working; kept as-is - confirm this is intended.
            db_document.run = 0
            db.commit()
            db.refresh(db_document)

        # Vectorization and data entry completed
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
        db_document.chunk_num = total_chunks
        db_document.progress = 1.0
        db_document.process_duration = time.time() - start_time
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
        db_document.progress_msg = progress_msg
        db_document.run = 0
        db.commit()
        result = f"parse document '{db_document.file_name}' processed successfully."
        return result
    except Exception as e:
        # The session may be in a failed state; reset it before recording the failure.
        db.rollback()
        if db_document is None:
            # The document row was never loaded, so there is nothing to update.
            return f"parse document '{document_id}' failed."
        # Write the full local log (including uncommitted parser progress)
        # followed by the failure reason.
        db_document.progress_msg = (
            f"{progress_msg}Failed to vectorize and import the parsed document:{str(e)}\n"
        )
        db_document.run = 0
        db.commit()
        result = f"parse document '{db_document.file_name}' failed."
        return result
    finally:
        db.close()
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(
    self,
    group_id: str,
    message: str,
    history: List[Dict[str, Any]],
    search_switch: str,
    config_id: str,
    storage_type: str,
    user_rag_memory_id: str,
) -> Dict[str, Any]:
    """Celery task that reads memory through ``MemoryAgentService.read_memory``.

    Args:
        group_id: Group ID for the memory agent.
        message: User message to process.
        history: Conversation history.
        search_switch: Search switch parameter.
        config_id: Configuration ID.
        storage_type: Forwarded unchanged to ``read_memory``.
        user_rag_memory_id: Forwarded unchanged to ``read_memory``.

    Returns:
        A dict with ``status`` (``"SUCCESS"`` or ``"FAILURE"``), either
        ``result`` or ``error``, plus ``group_id``, ``config_id``,
        ``elapsed_time`` and ``task_id``. Failures of the service call are
        reported through this dict rather than raised.
    """
    began = time.time()

    async def _run() -> str:
        agent = MemoryAgentService()
        return await agent.read_memory(
            group_id, message, history, search_switch, config_id, storage_type, user_rag_memory_id
        )

    try:
        # Use nest_asyncio, when installed, to avoid event loop conflicts.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        # Reuse the current event loop if there is a usable one; otherwise create one.
        try:
            event_loop = asyncio.get_event_loop()
            if event_loop.is_closed():
                event_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(event_loop)
        except RuntimeError:
            event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(event_loop)

        answer = event_loop.run_until_complete(_run())
        took = time.time() - began
        return {
            "status": "SUCCESS",
            "result": answer,
            "group_id": group_id,
            "config_id": config_id,
            "elapsed_time": took,
            "task_id": self.request.id,
        }
    except Exception as exc:
        took = time.time() - began
        return {
            "status": "FAILURE",
            "error": str(exc),
            "group_id": group_id,
            "config_id": config_id,
            "elapsed_time": took,
            "task_id": self.request.id,
        }
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(
    self,
    group_id: str,
    message: str,
    config_id: str,
    storage_type: str,
    user_rag_memory_id: str,
) -> Dict[str, Any]:
    """Celery task that writes memory through ``MemoryAgentService.write_memory``.

    Args:
        group_id: Group ID for the memory agent.
        message: Message to write.
        config_id: Configuration ID.
        storage_type: Forwarded unchanged to ``write_memory``.
        user_rag_memory_id: Forwarded unchanged to ``write_memory``.

    Returns:
        A dict with ``status`` (``"SUCCESS"`` or ``"FAILURE"``), either
        ``result`` or ``error``, plus ``group_id``, ``config_id``,
        ``elapsed_time`` and ``task_id``. Failures of the service call are
        reported through this dict rather than raised.
    """
    began = time.time()

    async def _run() -> str:
        agent = MemoryAgentService()
        return await agent.write_memory(
            group_id, message, config_id, storage_type, user_rag_memory_id
        )

    try:
        # Use nest_asyncio, when installed, to avoid event loop conflicts.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        # Reuse the current event loop if there is a usable one; otherwise create one.
        try:
            event_loop = asyncio.get_event_loop()
            if event_loop.is_closed():
                event_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(event_loop)
        except RuntimeError:
            event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(event_loop)

        outcome = event_loop.run_until_complete(_run())
        took = time.time() - began
        return {
            "status": "SUCCESS",
            "result": outcome,
            "group_id": group_id,
            "config_id": config_id,
            "elapsed_time": took,
            "task_id": self.request.id,
        }
    except Exception as exc:
        took = time.time() - began
        return {
            "status": "FAILURE",
            "error": str(exc),
            "group_id": group_id,
            "config_id": config_id,
            "elapsed_time": took,
            "task_id": self.request.id,
        }
def reflection_engine(host_id: Optional[uuid.UUID] = None) -> None:
    """Run one self-reflexion pass for a host.

    Imports ``self_reflexion`` lazily and drives the coroutine to completion
    with ``asyncio.run``.

    Args:
        host_id: Host to run the reflexion for. When omitted, the previously
            hard-coded default host is used, so existing callers are unaffected.

    Raises:
        Whatever ``self_reflexion`` raises; nothing is caught here.
    """
    from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion

    if host_id is None:
        # Default host kept from the original implementation.
        host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
    asyncio.run(self_reflexion(host_id))
@celery_app.task(name="app.core.memory.agent.reflection.timer")
def reflection_timer_task() -> None:
    """Celery entry point that triggers the reflection engine.

    Delegates to ``reflection_engine`` with no arguments. Nothing is caught
    here, so any exception raised by the engine propagates to Celery.
    """
    reflection_engine()
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
def check_read_service_task() -> Dict[str, str]:
    """Probe the read_service HTTP endpoint and publish the result to Redis.

    Posts a fixed health-check payload to the read_service API, converts the
    outcome into a status record, stores it in the Redis hash
    ``memsci:health:read_service`` and sets that key to expire after
    ``settings.HEALTH_CHECK_SECONDS`` seconds.

    Returns:
        The status record written to Redis, with the string fields
        ``status``, ``msg``, ``error``, ``code`` and ``time``.

    Raises:
        Errors from the Redis write are not caught and propagate to Celery;
        errors from the HTTP request are recorded in the status record.
    """
    client = redis.Redis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        db=settings.REDIS_DB,
        password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
    )
    try:
        try:
            api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
            payload = {
                "user_id": "健康检查",
                "apply_id": "健康检查",
                "group_id": "健康检查",
                "message": "你好",
                "history": [],
                "search_switch": "2",
            }
            resp = requests.post(api_url, json=payload, timeout=15)
            ok = resp.status_code == 200
            status = "Success" if ok else "Fail"
            msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
            error = "" if ok else resp.text
            code = 0 if ok else 500
        except Exception as e:
            # Any request error (connection refused, timeout, ...) counts as a failed check.
            status = "Fail"
            msg = "接口请求失败"
            error = str(e)
            code = 500

        # Redis hash values are stored as strings, so stringify numbers up front.
        data = {
            "status": status,
            "msg": msg,
            "error": error,
            "code": str(code),
            "time": str(int(time.time())),
        }
        client.hset("memsci:health:read_service", mapping=data)
        client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
        return data
    finally:
        # A new client is created on every run; release its connections so
        # periodic executions do not accumulate open sockets.
        client.close()
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
    """Scheduled task: total the memory of every end user in a workspace and store it.

    Looks up the apps of the workspace, collects their distinct end users,
    queries ``search_all`` for each end user, sums the ``total`` values and
    persists the sum through ``write_memory_increment``.

    Args:
        workspace_id: Workspace ID as a UUID string.

    Returns:
        A dict describing the run. On success it carries ``status``
        ``"SUCCESS"``, the totals, the ID and creation time of the stored
        record and ``elapsed_time``; on failure it carries ``status``
        ``"FAILURE"``, ``error``, ``workspace_id`` and ``elapsed_time``.
    """
    began = time.time()

    async def _run() -> Dict[str, Any]:
        from app.services.memory_storage_service import search_all
        from app.repositories.memory_increment_repository import write_memory_increment
        from app.models.end_user_model import EndUser
        from app.models.app_model import App

        session = next(get_db())
        try:
            ws_uuid = uuid.UUID(workspace_id)

            # 1. Fetch every app that belongs to the workspace.
            workspace_apps = session.query(App).filter(App.workspace_id == ws_uuid).all()
            if not workspace_apps:
                # No apps means there is nothing to count: record a total of 0.
                record = write_memory_increment(
                    db=session,
                    workspace_id=ws_uuid,
                    total_num=0
                )
                return {
                    "status": "SUCCESS",
                    "workspace_id": workspace_id,
                    "total_num": 0,
                    "end_user_count": 0,
                    "memory_increment_id": str(record.id),
                    "created_at": record.created_at.isoformat(),
                }

            # 2. Collect the distinct end-user IDs across those apps.
            app_ids = [one_app.id for one_app in workspace_apps]
            user_rows = (
                session.query(EndUser.id)
                .filter(EndUser.app_id.in_(app_ids))
                .distinct()
                .all()
            )

            # 3. Query each end user's memory total and accumulate it.
            grand_total = 0
            per_user = []
            for (end_user_id,) in user_rows:
                try:
                    found = await search_all(str(end_user_id))
                    count = found.get("total", 0)
                    grand_total += count
                    per_user.append({
                        "end_user_id": str(end_user_id),
                        "total": count
                    })
                except Exception as exc:
                    # Record the failure for this user but keep processing the rest.
                    per_user.append({
                        "end_user_id": str(end_user_id),
                        "total": 0,
                        "error": str(exc)
                    })

            # 4. Persist the aggregated total.
            record = write_memory_increment(
                db=session,
                workspace_id=ws_uuid,
                total_num=grand_total
            )
            return {
                "status": "SUCCESS",
                "workspace_id": workspace_id,
                "total_num": grand_total,
                "end_user_count": len(user_rows),
                "end_user_details": per_user,
                "memory_increment_id": str(record.id),
                "created_at": record.created_at.isoformat(),
            }
        finally:
            session.close()

    try:
        outcome = asyncio.run(_run())
        took = time.time() - began
        outcome["elapsed_time"] = took
        return outcome
    except Exception as exc:
        took = time.time() - began
        return {
            "status": "FAILURE",
            "error": str(exc),
            "workspace_id": workspace_id,
            "elapsed_time": took,
        }