feat: Add base project structure with API and web components
This commit is contained in:
0
api/app/core/rag/common/__init__.py
Normal file
0
api/app/core/rag/common/__init__.py
Normal file
106
api/app/core/rag/common/connection_utils.py
Normal file
106
api/app/core/rag/common/connection_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, Callable, Coroutine, Optional, Type, Union
|
||||
import asyncio
|
||||
import trio
|
||||
from functools import wraps
|
||||
from flask import make_response, jsonify
|
||||
from .constants import RetCode
|
||||
|
||||
TimeoutException = Union[Type[BaseException], BaseException]
|
||||
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
|
||||
|
||||
|
||||
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
|
||||
on_timeout: Optional[OnTimeoutCallback] = None):
|
||||
if isinstance(seconds, str):
|
||||
seconds = float(seconds)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result_queue = queue.Queue(maxsize=1)
|
||||
|
||||
def target():
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
result_queue.put(result)
|
||||
except Exception as e:
|
||||
result_queue.put(e)
|
||||
|
||||
thread = threading.Thread(target=target)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
result = result_queue.get(timeout=seconds)
|
||||
else:
|
||||
result = result_queue.get()
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
if seconds is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
for a in range(attempts):
|
||||
try:
|
||||
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
|
||||
with trio.fail_after(seconds):
|
||||
return await func(*args, **kwargs)
|
||||
else:
|
||||
return await func(*args, **kwargs)
|
||||
except trio.TooSlowError:
|
||||
if a < attempts - 1:
|
||||
continue
|
||||
if on_timeout is not None:
|
||||
if callable(on_timeout):
|
||||
result = on_timeout()
|
||||
if isinstance(result, Coroutine):
|
||||
return await result
|
||||
return result
|
||||
return on_timeout
|
||||
|
||||
if exception is None:
|
||||
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
if isinstance(exception, BaseException):
|
||||
raise exception
|
||||
|
||||
if isinstance(exception, type) and issubclass(exception, BaseException):
|
||||
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
|
||||
|
||||
raise RuntimeError("Invalid exception type provided")
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
    """Build a JSON Flask response with permissive CORS headers.

    The JSON body contains ``code`` always, and ``message`` / ``data`` only
    when they are not ``None``.

    Args:
        code: Result code placed under the ``"code"`` key.
        message: Text placed under the ``"message"`` key.
        data: Payload placed under the ``"data"`` key.
        auth: When truthy, sent back in the ``Authorization`` response header.

    Returns:
        The Flask response object produced by ``make_response(jsonify(...))``.
    """
    payload = {"code": code, "message": message, "data": data}
    # "code" is always emitted, even when it is None; other None fields are dropped.
    response_dict = {key: value for key, value in payload.items() if value is not None or key == "code"}

    response = make_response(jsonify(response_dict))
    if auth:
        response.headers["Authorization"] = auth
    response.headers["Access-Control-Allow-Origin"] = "*"
    # The standard CORS header name is plural ("Methods"); the singular form is ignored by browsers.
    response.headers["Access-Control-Allow-Methods"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    response.headers["Access-Control-Expose-Headers"] = "Authorization"
    return response
|
||||
180
api/app/core/rag/common/constants.py
Normal file
180
api/app/core/rag/common/constants.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from enum import Enum, IntEnum
|
||||
from strenum import StrEnum
|
||||
|
||||
SERVICE_CONF = "service_conf.yaml"
|
||||
RAG_SERVICE_NAME = "rag"
|
||||
|
||||
class CustomEnum(Enum):
    """Enum base class adding membership and listing helpers."""

    @classmethod
    def valid(cls, value):
        """Return True when ``value`` is the value of some member of this enum."""
        try:
            cls(value)
        except (ValueError, TypeError):
            # Enum lookup signals "no such member" with ValueError (TypeError
            # for a misbehaving ``_missing_``). Anything else, including
            # KeyboardInterrupt / SystemExit, must not be swallowed here.
            return False
        return True

    @classmethod
    def values(cls):
        """Return the values of all entries in ``__members__``, in definition order."""
        return [member.value for member in cls.__members__.values()]

    @classmethod
    def names(cls):
        """Return the names of all entries in ``__members__``, in definition order."""
        return [member.name for member in cls.__members__.values()]
|
||||
|
||||
|
||||
class RetCode(IntEnum, CustomEnum):
    """Integer result codes; also inherits the CustomEnum helper classmethods."""
    SUCCESS = 0
    NOT_EFFECTIVE = 10
    EXCEPTION_ERROR = 100
    ARGUMENT_ERROR = 101
    DATA_ERROR = 102
    OPERATING_ERROR = 103
    CONNECTION_ERROR = 105
    RUNNING = 106
    PERMISSION_ERROR = 108
    AUTHENTICATION_ERROR = 109
    # The values below coincide with HTTP status codes.
    UNAUTHORIZED = 401
    SERVER_ERROR = 500
    FORBIDDEN = 403
    NOT_FOUND = 404
|
||||
|
||||
|
||||
class StatusEnum(Enum):
    """Validity flag whose members carry the strings "1" (valid) and "0" (invalid)."""
    VALID = "1"
    INVALID = "0"
|
||||
|
||||
|
||||
class ActiveEnum(Enum):
    """Activity flag whose members carry the strings "1" (active) and "0" (inactive)."""
    ACTIVE = "1"
    INACTIVE = "0"
|
||||
|
||||
|
||||
class LLMType(StrEnum):
    """String-valued identifiers for the kinds of model listed below."""
    CHAT = 'chat'
    EMBEDDING = 'embedding'
    SPEECH2TEXT = 'speech2text'
    IMAGE2TEXT = 'image2text'
    RERANK = 'rerank'
    TTS = 'tts'
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
    """Task states, encoded as the digit strings "0" through "5"."""
    UNSTART = "0"
    RUNNING = "1"
    CANCEL = "2"
    DONE = "3"
    FAIL = "4"
    SCHEDULE = "5"


# Set containing every TaskStatus member.
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
                     TaskStatus.SCHEDULE}
|
||||
|
||||
|
||||
class ParserType(StrEnum):
    """String-valued identifiers for the available parser types."""
    PRESENTATION = "presentation"
    LAWS = "laws"
    MANUAL = "manual"
    PAPER = "paper"
    RESUME = "resume"
    BOOK = "book"
    QA = "qa"
    TABLE = "table"
    NAIVE = "naive"
    PICTURE = "picture"
    ONE = "one"
    AUDIO = "audio"
    EMAIL = "email"
    # Note: the member name is abbreviated but the value is spelled out.
    KG = "knowledge_graph"
    TAG = "tag"
|
||||
|
||||
|
||||
class FileSource(StrEnum):
    """String-valued identifiers for where a file comes from."""
    # LOCAL is represented by the empty string.
    LOCAL = ""
    KNOWLEDGEBASE = "knowledgebase"
    S3 = "s3"
    NOTION = "notion"
    DISCORD = "discord"
    CONFLUENCE = "confluence"
    GMAIL = "gmail"
    GOOGLE_DRIVE = "google_drive"
    JIRA = "jira"
    SHAREPOINT = "sharepoint"
    SLACK = "slack"
    TEAMS = "teams"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
    """String-valued identifiers for pipeline task types (values are mixed-case)."""
    PARSE = "Parse"
    DOWNLOAD = "Download"
    RAPTOR = "RAPTOR"
    GRAPH_RAG = "GraphRAG"
    MINDMAP = "Mindmap"


# Set containing every PipelineTaskType member.
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
                             PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
|
||||
|
||||
class MCPServerType(StrEnum):
    """String-valued identifiers for MCP server types."""
    SSE = "sse"
    STREAMABLE_HTTP = "streamable-http"


# Set containing every MCPServerType member.
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||
|
||||
class Storage(Enum):
    """Integer-valued identifiers for storage options (presumably storage backends -- confirm)."""
    MINIO = 1
    AZURE_SPN = 2
    AZURE_SAS = 3
    AWS_S3 = 4
    OSS = 5
    OPENDAL = 6
|
||||
|
||||
# environment
|
||||
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
|
||||
# ENV_RAG_SECRET_KEY = "RAG_SECRET_KEY"
|
||||
# ENV_REGISTER_ENABLED = "REGISTER_ENABLED"
|
||||
# ENV_DOC_ENGINE = "DOC_ENGINE"
|
||||
# ENV_SANDBOX_ENABLED = "SANDBOX_ENABLED"
|
||||
# ENV_SANDBOX_HOST = "SANDBOX_HOST"
|
||||
# ENV_MAX_CONTENT_LENGTH = "MAX_CONTENT_LENGTH"
|
||||
# ENV_COMPONENT_EXEC_TIMEOUT = "COMPONENT_EXEC_TIMEOUT"
|
||||
# ENV_TRINO_USE_TLS = "TRINO_USE_TLS"
|
||||
# ENV_MAX_FILE_NUM_PER_USER = "MAX_FILE_NUM_PER_USER"
|
||||
# ENV_MACOS = "MACOS"
|
||||
# ENV_RAG_DEBUGPY_LISTEN = "RAG_DEBUGPY_LISTEN"
|
||||
# ENV_WERKZEUG_RUN_MAIN = "WERKZEUG_RUN_MAIN"
|
||||
# ENV_DISABLE_SDK = "DISABLE_SDK"
|
||||
# ENV_ENABLE_TIMEOUT_ASSERTION = "ENABLE_TIMEOUT_ASSERTION"
|
||||
# ENV_LOG_LEVELS = "LOG_LEVELS"
|
||||
# ENV_TENSORRT_DLA_SVR = "TENSORRT_DLA_SVR"
|
||||
# ENV_OCR_GPU_MEM_LIMIT_MB = "OCR_GPU_MEM_LIMIT_MB"
|
||||
# ENV_OCR_ARENA_EXTEND_STRATEGY = "OCR_ARENA_EXTEND_STRATEGY"
|
||||
# ENV_MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = "MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK"
|
||||
# ENV_MAX_MAX_CONCURRENT_CHATS = "MAX_CONCURRENT_CHATS"
|
||||
# ENV_RAG_MCP_BASE_URL = "RAG_MCP_BASE_URL"
|
||||
# ENV_RAG_MCP_HOST = "RAG_MCP_HOST"
|
||||
# ENV_RAG_MCP_PORT = "RAG_MCP_PORT"
|
||||
# ENV_RAG_MCP_LAUNCH_MODE = "RAG_MCP_LAUNCH_MODE"
|
||||
# ENV_RAG_MCP_HOST_API_KEY = "RAG_MCP_HOST_API_KEY"
|
||||
# ENV_MINERU_EXECUTABLE = "MINERU_EXECUTABLE"
|
||||
# ENV_MINERU_APISERVER = "MINERU_APISERVER"
|
||||
# ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR"
|
||||
# ENV_MINERU_BACKEND = "MINERU_BACKEND"
|
||||
# ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT"
|
||||
# ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR"
|
||||
# ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS"
|
||||
# ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES"
|
||||
# ENV_LLM_BASE_DELAY = "LLM_BASE_DELAY"
|
||||
# ENV_OLLAMA_KEEP_ALIVE = "OLLAMA_KEEP_ALIVE"
|
||||
# ENV_DOC_BULK_SIZE = "DOC_BULK_SIZE"
|
||||
# ENV_EMBEDDING_BATCH_SIZE = "EMBEDDING_BATCH_SIZE"
|
||||
# ENV_MAX_CONCURRENT_TASKS = "MAX_CONCURRENT_TASKS"
|
||||
# ENV_MAX_CONCURRENT_CHUNK_BUILDERS = "MAX_CONCURRENT_CHUNK_BUILDERS"
|
||||
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
|
||||
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
|
||||
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"
|
||||
|
||||
PAGERANK_FLD = "pagerank_fea"
|
||||
SVR_QUEUE_NAME = "rag_svr_queue"
|
||||
SVR_CONSUMER_GROUP_NAME = "rag_svr_task_broker"
|
||||
TAG_FLD = "tag_feas"
|
||||
28
api/app/core/rag/common/file_utils.py
Normal file
28
api/app/core/rag/common/file_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import os
|
||||
|
||||
# Taken from RAG_PROJECT_BASE, else RAG_DEPLOY_BASE; stays None when neither
# is set to a non-empty value, in which case it is computed on first use.
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")


def get_project_base_directory(*args):
    """Return the project base directory, optionally joined with ``args``.

    On first use, when no base was supplied through the environment, the base
    is set to the absolute path four directory levels above the directory
    containing this file, and stored in the module-level ``PROJECT_BASE``.

    Args:
        *args: Path components appended to the base with ``os.path.join``.

    Returns:
        The base directory, or the base joined with ``args`` when any are given.
    """
    global PROJECT_BASE
    if PROJECT_BASE is None:
        this_dir = os.path.dirname(os.path.realpath(__file__))
        four_levels_up = [os.pardir] * 4
        PROJECT_BASE = os.path.abspath(os.path.join(this_dir, *four_levels_up))

    return os.path.join(PROJECT_BASE, *args) if args else PROJECT_BASE
|
||||
|
||||
|
||||
def traversal_files(base):
    """Yield the path of every file that ``os.walk`` finds under ``base``.

    Each yielded path is the walked directory joined with the file name.
    """
    for dirpath, _subdirs, filenames in os.walk(base):
        yield from (os.path.join(dirpath, filename) for filename in filenames)
|
||||
30
api/app/core/rag/common/float_utils.py
Normal file
30
api/app/core/rag/common/float_utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
def get_float(v):
    """Convert ``v`` to a float, falling back to negative infinity.

    Args:
        v: Anything ``float()`` accepts, or ``None``.

    Returns:
        ``float(v)`` when the conversion succeeds; ``float('-inf')`` when ``v``
        is ``None`` or the conversion raises.

    Examples:
        >>> get_float("3.14")
        3.14
        >>> get_float(None)
        -inf
        >>> get_float("invalid")
        -inf
        >>> get_float(42)
        42.0
    """
    fallback = float('-inf')
    if v is None:
        return fallback
    try:
        converted = float(v)
    except Exception:
        return fallback
    return converted
|
||||
92
api/app/core/rag/common/misc_utils.py
Normal file
92
api/app/core/rag/common/misc_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import uuid
|
||||
import requests
|
||||
import threading
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
def get_uuid():
    """Return a new ``uuid.uuid1()`` as a 32-character lowercase hex string."""
    fresh = uuid.uuid1()
    return fresh.hex
|
||||
|
||||
|
||||
def download_img(url, timeout=30):
    """Download ``url`` and return its content as a base64 ``data:`` URI.

    Args:
        url: Address to fetch. A falsy value (``None``, ``""``) short-circuits.
        timeout: Seconds passed to ``requests.get``. Without a timeout the
            request could block indefinitely on an unresponsive server.

    Returns:
        ``""`` when ``url`` is falsy; otherwise
        ``"data:<content-type>;base64,<payload>"``, where the content type
        comes from the response's ``Content-Type`` header and defaults to
        ``image/jpg``.

    Raises:
        requests.exceptions.RequestException: Propagated from ``requests.get``
            (including timeouts).
    """
    if not url:
        return ""
    # NOTE(review): the HTTP status is not checked, so an error page would be
    # encoded as if it were an image -- confirm whether callers rely on that.
    response = requests.get(url, timeout=timeout)
    content_type = response.headers.get('Content-Type', 'image/jpg')
    encoded = base64.b64encode(response.content).decode("utf-8")
    return "data:" + content_type + ";" + "base64," + encoded
|
||||
|
||||
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
    """Map ``line`` to an integer in ``[0, mod)`` using its SHA-1 digest.

    The string is UTF-8 encoded, hashed with SHA-1, and the digest is
    interpreted as a big-endian integer reduced modulo ``mod``.
    """
    digest = hashlib.sha1(line.encode("utf-8")).digest()
    return int.from_bytes(digest, "big") % mod
|
||||
|
||||
def convert_bytes(size_in_bytes: int) -> str:
    """Format a byte count as a short human-readable string.

    The value is divided by 1024 until it drops below 1024 or the largest
    unit (PB) is reached. Plain bytes and values of 100 or more are shown
    without decimals, values from 10 up to 100 with one decimal, and smaller
    values with two.
    """
    if size_in_bytes == 0:
        return "0 B"

    units = ('B', 'KB', 'MB', 'GB', 'TB', 'PB')
    value = float(size_in_bytes)
    unit_index = 0
    # Step up one unit at a time; stop as soon as the value fits.
    for next_index in range(1, len(units)):
        if value >= 1024:
            value /= 1024
            unit_index = next_index
        else:
            break

    if unit_index == 0 or value >= 100:
        decimals = 0
    elif value >= 10:
        decimals = 1
    else:
        decimals = 2
    return f"{value:.{decimals}f} {units[unit_index]}"
|
||||
|
||||
|
||||
def once(func):
    """Decorate ``func`` so that it is invoked at most once.

    The first call runs ``func`` with the given arguments and stores its
    return value; every later call returns that stored value and ignores its
    own arguments. A lock serialises callers, so concurrent first calls from
    several threads still invoke ``func`` only once.

    The "executed" flag is set before ``func`` runs. If that first call
    raises, the exception propagates to that caller and later calls return
    ``None`` without retrying.

    Args:
        func (callable): The function to be executed only once.

    Returns:
        callable: A wrapper carrying ``func``'s name and docstring.

    Example:
        @once
        def compute_expensive_value():
            print("Computing...")
            return 42

        # First call: executes and prints
        # Subsequent calls: return 42 without executing
    """
    # Imported locally so the decorator does not depend on a module-level import.
    from functools import wraps

    executed = False
    result = None
    lock = threading.Lock()

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed, result
        with lock:
            if not executed:
                executed = True
                result = func(*args, **kwargs)
            return result

    return wrapper
|
||||
|
||||
@once
def pip_install_torch():
    """Install PyTorch with pip unless the ``DEVICE`` environment variable is "cpu".

    ``DEVICE`` defaults to "cpu", in which case nothing is installed. Otherwise
    pip is run through the current interpreter, and a non-zero exit status
    raises ``subprocess.CalledProcessError``.
    """
    if os.getenv("DEVICE", "cpu") == "cpu":
        return

    logging.info("Installing pytorch")
    requirements = ["torch>=2.5.0,<3.0.0"]
    install_cmd = [sys.executable, "-m", "pip", "install"] + requirements
    subprocess.check_call(install_cmd)
|
||||
2
api/app/core/rag/common/settings.py
Normal file
2
api/app/core/rag/common/settings.py
Normal file
@@ -0,0 +1,2 @@
|
||||
PARALLEL_DEVICES: int = 0
|
||||
|
||||
57
api/app/core/rag/common/string_utils.py
Normal file
57
api/app/core/rag/common/string_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_redundant_spaces(txt: str):
    """Strip runs of spaces that sit next to punctuation-like characters.

    Two case-insensitive substitutions are applied in order:

    1. Spaces are removed when they follow any character other than a letter,
       digit, ``.``, ``,``, ``)`` or ``>`` and precede a non-space,
       e.g. ``"( test"`` becomes ``"(test"``.
    2. Spaces are removed when they follow a non-space and precede any
       character other than a letter, digit, ``.``, ``,``, ``(`` or ``<``,
       e.g. ``"world !"`` becomes ``"world!"``.

    Args:
        txt (str): Input text to process.

    Returns:
        str: The text after both substitutions.
    """
    spaces_after_left_boundary = r"([^a-z0-9.,\)>]) +([^ ])"
    spaces_before_right_boundary = r"([^ ]) +([^a-z0-9.,\(<])"

    for pattern in (spaces_after_left_boundary, spaces_before_right_boundary):
        txt = re.sub(pattern, r"\1\2", txt, flags=re.IGNORECASE)
    return txt
|
||||
|
||||
|
||||
def clean_markdown_block(text):
    """Strip a wrapping ```markdown fence from ``text``.

    A leading ```markdown marker (with surrounding whitespace and its line
    break) and a trailing ``` marker (with surrounding whitespace) are
    removed when present.

    Args:
        text (str): Text that may be wrapped in a Markdown code fence.

    Returns:
        str: The unwrapped text with leading and trailing whitespace stripped.
    """
    # Opening fence: optional whitespace, ```markdown, optional whitespace/newline.
    without_opening = re.sub(r'^\s*```markdown\s*\n?', '', text)
    # Closing fence: optional newline and whitespace, ```, optional trailing whitespace.
    without_closing = re.sub(r'\n?\s*```\s*$', '', without_opening)
    return without_closing.strip()
|
||||
59
api/app/core/rag/common/token_utils.py
Normal file
59
api/app/core/rag/common/token_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
import tiktoken
|
||||
|
||||
from .file_utils import get_project_base_directory
|
||||
|
||||
# Point tiktoken at "<project base>/res" through the TIKTOKEN_CACHE_DIR
# environment variable; this runs at import time, before the encoder is created.
tiktoken_cache_dir = os.path.join(get_project_base_directory(), "res")
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
# Module-wide "cl100k_base" encoder used by the token helpers in this module.
encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str) -> int:
    """Return how many tokens the module-level encoder produces for ``string``.

    Any exception raised while encoding is swallowed and reported as 0.
    """
    try:
        return len(encoder.encode(string))
    except Exception:
        return 0
|
||||
|
||||
def _token_field(container, key):
    """Return ``container[key]``, or None when it is not subscriptable or lacks ``key``."""
    try:
        return container[key]
    except (KeyError, IndexError, TypeError):
        return None


def _sum_io_tokens(container):
    """Return ``input_tokens + output_tokens`` from ``container``, or None if unavailable."""
    consumed = _token_field(container, "input_tokens")
    produced = _token_field(container, "output_tokens")
    if consumed is None or produced is None:
        return None
    try:
        return consumed + produced
    except TypeError:
        return None


def total_token_count_from_response(resp):
    """Extract the total token count from a response of several possible shapes.

    The following locations are probed in order and the first usable value wins:

    1. ``resp.usage.total_tokens`` (attribute access)
    2. ``resp.usage_metadata.total_tokens`` (attribute access)
    3. ``resp["usage"]["total_tokens"]``
    4. ``resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]``
    5. ``resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]``

    Responses that are not subscriptable, or whose intermediate values are
    ``None``, are skipped rather than raising ``TypeError``. A ``None`` total
    is treated as missing instead of being returned as the count.

    Args:
        resp: Response object or mapping; may be ``None``.

    Returns:
        The token count found, or 0 when none of the locations yields one.
    """
    if resp is None:
        return 0

    # Attribute-style responses.
    for holder_name in ("usage", "usage_metadata"):
        total = getattr(getattr(resp, holder_name, None), "total_tokens", None)
        if total is not None:
            return total

    # Mapping-style responses.
    usage = _token_field(resp, "usage")
    total = _token_field(usage, "total_tokens")
    if total is not None:
        return total

    meta_tokens = _token_field(_token_field(resp, "meta"), "tokens")
    for container in (usage, meta_tokens):
        combined = _sum_io_tokens(container)
        if combined is not None:
            return combined

    return 0
|
||||
|
||||
|
||||
def truncate(string: str, max_len: int) -> str:
    """Return ``string`` cut down to its first ``max_len`` tokens.

    The text is encoded with the module-level encoder, the token list is
    sliced to ``max_len`` entries, and the kept tokens are decoded back to text.
    """
    token_ids = encoder.encode(string)
    kept = token_ids[:max_len]
    return encoder.decode(kept)
|
||||
|
||||
Reference in New Issue
Block a user