feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

View File

@@ -0,0 +1,106 @@
import os
import queue
import threading
from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio
import trio
from functools import wraps
from flask import make_response, jsonify
from .constants import RetCode
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts - 1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
for key, value in result_dict.items():
if value is None and key != "code":
continue
else:
response_dict[key] = value
response = make_response(jsonify(response_dict))
if auth:
response.headers["Authorization"] = auth
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Method"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response

View File

@@ -0,0 +1,180 @@
from enum import Enum, IntEnum
from strenum import StrEnum
SERVICE_CONF = "service_conf.yaml"
RAG_SERVICE_NAME = "rag"
class CustomEnum(Enum):
@classmethod
def valid(cls, value):
try:
cls(value)
return True
except BaseException:
return False
@classmethod
def values(cls):
return [member.value for member in cls.__members__.values()]
@classmethod
def names(cls):
return [member.name for member in cls.__members__.values()]
class RetCode(IntEnum, CustomEnum):
SUCCESS = 0
NOT_EFFECTIVE = 10
EXCEPTION_ERROR = 100
ARGUMENT_ERROR = 101
DATA_ERROR = 102
OPERATING_ERROR = 103
CONNECTION_ERROR = 105
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
UNAUTHORIZED = 401
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
class StatusEnum(Enum):
VALID = "1"
INVALID = "0"
class ActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'
class TaskStatus(StrEnum):
UNSTART = "0"
RUNNING = "1"
CANCEL = "2"
DONE = "3"
FAIL = "4"
SCHEDULE = "5"
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
TaskStatus.SCHEDULE}
class ParserType(StrEnum):
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"
PAPER = "paper"
RESUME = "resume"
BOOK = "book"
QA = "qa"
TABLE = "table"
NAIVE = "naive"
PICTURE = "picture"
ONE = "one"
AUDIO = "audio"
EMAIL = "email"
KG = "knowledge_graph"
TAG = "tag"
class FileSource(StrEnum):
LOCAL = ""
KNOWLEDGEBASE = "knowledgebase"
S3 = "s3"
NOTION = "notion"
DISCORD = "discord"
CONFLUENCE = "confluence"
GMAIL = "gmail"
GOOGLE_DRIVE = "google_drive"
JIRA = "jira"
SHAREPOINT = "sharepoint"
SLACK = "slack"
TEAMS = "teams"
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "Download"
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
MINDMAP = "Mindmap"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
class MCPServerType(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
AZURE_SAS = 3
AWS_S3 = 4
OSS = 5
OPENDAL = 6
# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
# ENV_RAG_SECRET_KEY = "RAG_SECRET_KEY"
# ENV_REGISTER_ENABLED = "REGISTER_ENABLED"
# ENV_DOC_ENGINE = "DOC_ENGINE"
# ENV_SANDBOX_ENABLED = "SANDBOX_ENABLED"
# ENV_SANDBOX_HOST = "SANDBOX_HOST"
# ENV_MAX_CONTENT_LENGTH = "MAX_CONTENT_LENGTH"
# ENV_COMPONENT_EXEC_TIMEOUT = "COMPONENT_EXEC_TIMEOUT"
# ENV_TRINO_USE_TLS = "TRINO_USE_TLS"
# ENV_MAX_FILE_NUM_PER_USER = "MAX_FILE_NUM_PER_USER"
# ENV_MACOS = "MACOS"
# ENV_RAG_DEBUGPY_LISTEN = "RAG_DEBUGPY_LISTEN"
# ENV_WERKZEUG_RUN_MAIN = "WERKZEUG_RUN_MAIN"
# ENV_DISABLE_SDK = "DISABLE_SDK"
# ENV_ENABLE_TIMEOUT_ASSERTION = "ENABLE_TIMEOUT_ASSERTION"
# ENV_LOG_LEVELS = "LOG_LEVELS"
# ENV_TENSORRT_DLA_SVR = "TENSORRT_DLA_SVR"
# ENV_OCR_GPU_MEM_LIMIT_MB = "OCR_GPU_MEM_LIMIT_MB"
# ENV_OCR_ARENA_EXTEND_STRATEGY = "OCR_ARENA_EXTEND_STRATEGY"
# ENV_MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = "MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK"
# ENV_MAX_MAX_CONCURRENT_CHATS = "MAX_CONCURRENT_CHATS"
# ENV_RAG_MCP_BASE_URL = "RAG_MCP_BASE_URL"
# ENV_RAG_MCP_HOST = "RAG_MCP_HOST"
# ENV_RAG_MCP_PORT = "RAG_MCP_PORT"
# ENV_RAG_MCP_LAUNCH_MODE = "RAG_MCP_LAUNCH_MODE"
# ENV_RAG_MCP_HOST_API_KEY = "RAG_MCP_HOST_API_KEY"
# ENV_MINERU_EXECUTABLE = "MINERU_EXECUTABLE"
# ENV_MINERU_APISERVER = "MINERU_APISERVER"
# ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR"
# ENV_MINERU_BACKEND = "MINERU_BACKEND"
# ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT"
# ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR"
# ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS"
# ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES"
# ENV_LLM_BASE_DELAY = "LLM_BASE_DELAY"
# ENV_OLLAMA_KEEP_ALIVE = "OLLAMA_KEEP_ALIVE"
# ENV_DOC_BULK_SIZE = "DOC_BULK_SIZE"
# ENV_EMBEDDING_BATCH_SIZE = "EMBEDDING_BATCH_SIZE"
# ENV_MAX_CONCURRENT_TASKS = "MAX_CONCURRENT_TASKS"
# ENV_MAX_CONCURRENT_CHUNK_BUILDERS = "MAX_CONCURRENT_CHUNK_BUILDERS"
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"
PAGERANK_FLD = "pagerank_fea"
SVR_QUEUE_NAME = "rag_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_svr_task_broker"
TAG_FLD = "tag_feas"

View File

@@ -0,0 +1,28 @@
import os
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
PROJECT_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.pardir,
os.pardir,
os.pardir,
)
)
if args:
return os.path.join(PROJECT_BASE, *args)
return PROJECT_BASE
def traversal_files(base):
for root, ds, fs in os.walk(base):
for f in fs:
fullname = os.path.join(root, f)
yield fullname

View File

@@ -0,0 +1,30 @@
def get_float(v):
"""
Convert a value to float, handling None and exceptions gracefully.
Attempts to convert the input value to a float. If the value is None or
cannot be converted to float, returns negative infinity as a default value.
Args:
v: The value to convert to float. Can be any type that float() accepts,
or None.
Returns:
float: The converted float value if successful, otherwise float('-inf').
Examples:
>>> get_float("3.14")
3.14
>>> get_float(None)
-inf
>>> get_float("invalid")
-inf
>>> get_float(42)
42.0
"""
if v is None:
return float('-inf')
try:
return float(v)
except Exception:
return float('-inf')

View File

@@ -0,0 +1,92 @@
import base64
import hashlib
import uuid
import requests
import threading
import subprocess
import sys
import os
import logging
def get_uuid():
return uuid.uuid1().hex
def download_img(url):
if not url:
return ""
response = requests.get(url)
return "data:" + \
response.headers.get('Content-Type', 'image/jpg') + ";" + \
"base64," + base64.b64encode(response.content).decode("utf-8")
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
def convert_bytes(size_in_bytes: int) -> str:
"""
Format size in bytes.
"""
if size_in_bytes == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
i = 0
size = float(size_in_bytes)
while size >= 1024 and i < len(units) - 1:
size /= 1024
i += 1
if i == 0 or size >= 100:
return f"{size:.0f} {units[i]}"
elif size >= 10:
return f"{size:.1f} {units[i]}"
else:
return f"{size:.2f} {units[i]}"
def once(func):
"""
A thread-safe decorator that ensures the decorated function runs exactly once,
caching and returning its result for all subsequent calls. This prevents
race conditions in multi-thread environments by using a lock to protect
the execution state.
Args:
func (callable): The function to be executed only once.
Returns:
callable: A wrapper function that executes `func` on the first call
and returns the cached result thereafter.
Example:
@once
def compute_expensive_value():
print("Computing...")
return 42
# First call: executes and prints
# Subsequent calls: return 42 without executing
"""
executed = False
result = None
lock = threading.Lock()
def wrapper(*args, **kwargs):
nonlocal executed, result
with lock:
if not executed:
executed = True
result = func(*args, **kwargs)
return result
return wrapper
@once
def pip_install_torch():
device = os.getenv("DEVICE", "cpu")
if device=="cpu":
return
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])

View File

@@ -0,0 +1,2 @@
PARALLEL_DEVICES: int = 0

View File

@@ -0,0 +1,57 @@
import re
def remove_redundant_spaces(txt: str):
"""
Remove redundant spaces around punctuation marks while preserving meaningful spaces.
This function performs two main operations:
1. Remove spaces after left-boundary characters (opening brackets, etc.)
2. Remove spaces before right-boundary characters (closing brackets, punctuation, etc.)
Args:
txt (str): Input text to process
Returns:
str: Text with redundant spaces removed
"""
# First pass: Remove spaces after left-boundary characters
# Matches: [non-alphanumeric-and-specific-right-punctuation] + [non-space]
# Removes spaces after characters like '(', '<', and other non-alphanumeric chars
# Examples:
# "( test" → "(test"
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
# Second pass: Remove spaces before right-boundary characters
# Matches: [non-space] + [non-alphanumeric-and-specific-left-punctuation]
# Removes spaces before characters like non-')', non-',', non-'.', and non-alphanumeric chars
# Examples:
# "world !" → "world!"
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)
def clean_markdown_block(text):
"""
Remove Markdown code block syntax from the beginning and end of text.
This function cleans Markdown code blocks by removing:
- Opening ```Markdown tags (with optional whitespace and newlines)
- Closing ``` tags (with optional whitespace and newlines)
Args:
text (str): Input text that may be wrapped in Markdown code blocks
Returns:
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
"""
# Remove opening ```markdown tag with optional whitespace and newlines
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
# Remove closing ``` tag with optional whitespace and newlines
# Matches: optional newline + optional whitespace + ``` + optional whitespace at end
text = re.sub(r'\n?\s*```\s*$', '', text)
# Return text with surrounding whitespace removed
return text.strip()

View File

@@ -0,0 +1,59 @@
import os
import tiktoken
from .file_utils import get_project_base_directory
tiktoken_cache_dir = os.path.join(get_project_base_directory(), "res")
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
encoder = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
try:
code_list = encoder.encode(string)
return len(code_list)
except Exception:
return 0
def total_token_count_from_response(resp):
if resp is None:
return 0
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
try:
return resp.usage.total_tokens
except Exception:
pass
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
try:
return resp.usage_metadata.total_tokens
except Exception:
pass
if 'usage' in resp and 'total_tokens' in resp['usage']:
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
try:
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
except Exception:
pass
if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
try:
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
except Exception:
pass
return 0
def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""
return encoder.decode(encoder.encode(string)[:max_len])