feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

@@ -0,0 +1,122 @@
English | [简体中文](./README_zh.md)
# *Deep*Doc
- [1. Introduction](#1)
- [2. Vision](#2)
- [3. Parser](#3)
<a name="1"></a>
## 1. Introduction
With a bunch of documents from various domains with various formats and along with diverse retrieval requirements,
an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose.
There are 2 parts in *Deep*Doc so far: vision and parser.
You can run the flowing test programs if you're interested in our results of OCR, layout recognition and TSR.
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './layouts_outputs'
--threshold THRESHOLD
A threshold to filter out detections. Default: 0.5
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
```
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!!
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a>
## 2. Vision
We use vision information to resolve problems as human being.
- OCR. Since a lot of documents presented as images or at least be able to transform to image,
OCR is a very essential and fundamental or even universal solution for text extraction.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
txt files which contain the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
- Layout recognition. Documents from different domain may have various layouts,
like, newspaper, magazine, book and résumé are distinct in terms of layout.
Only when machine have an accurate layout analysis, it can decide if these text parts are successive or not,
or this part needs Table Structure Recognition(TSR) to process, or this part is a figure and described with this caption.
We have 10 basic layout components which covers most cases:
- Text
- Title
- Figure
- Figure caption
- Table
- Table caption
- Header
- Footer
- Reference
- Equation
Have a try on the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
- Table Structure Recognition(TSR). Data table is a frequently used structure to present data including numbers or text.
And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM.
We have five labels for TSR task:
- Column
- Row
- Column header
- Projected row header
- Spanning cell
Have a try on the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has both images and html pages which demonstrate the detection results as following:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
<a name="3"></a>
## 3. Parser
Four kinds of document formats as PDF, DOCX, EXCEL and PPT have their corresponding parser.
The most complex one is PDF parser since PDF's flexibility. The output of PDF parser includes:
- Text chunks with their own positions in PDF(page number and rectangular positions).
- Tables with cropped image from the PDF, and contents which has already translated into natural language sentences.
- Figures with caption and text in the figures.
### Résumé
The résumé is a very complicated kind of document. A résumé which is composed of unstructured text
with various layouts could be resolved into structured data composed of nearly a hundred of fields.
We haven't opened the parser yet, as we open the processing method after parsing procedure.

View File

@@ -0,0 +1,116 @@
[English](./README.md) | 简体中文
# *Deep*Doc
- [*Deep*Doc](#deepdoc)
- [1. 介绍](#1-介绍)
- [2. 视觉处理](#2-视觉处理)
- [3. 解析器](#3-解析器)
- [简历](#简历)
<a name="1"></a>
## 1. 介绍
对于来自不同领域、具有不同格式和不同检索要求的大量文档,准确的分析成为一项极具挑战性的任务。*Deep*Doc 就是为了这个目的而诞生的。到目前为止,*Deep*Doc 中有两个组成部分视觉处理和解析器。如果您对我们的OCR、布局识别和TSR结果感兴趣您可以运行下面的测试程序。
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './layouts_outputs'
--threshold THRESHOLD
A threshold to filter out detections. Default: 0.5
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
```
HuggingFace为我们的模型提供服务。如果你在下载HuggingFace模型时遇到问题这可能会有所帮助
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a>
## 2. 视觉处理
作为人类,我们使用视觉信息来解决问题。
- **OCROptical Character Recognition光学字符识别**。由于许多文档都是以图像形式呈现的或者至少能够转换为图像因此OCR是文本提取的一个非常重要、基本甚至通用的解决方案。
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` 其中有演示结果位置的图像以及包含OCR文本的txt文件。
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
- 布局识别Layout recognition。来自不同领域的文件可能有不同的布局如报纸、杂志、书籍和简历在布局方面是不同的。只有当机器有准确的布局分析时它才能决定这些文本部分是连续的还是不连续的或者这个部分需要表结构识别Table Structure RecognitionTSR来处理或者这个部件是一个图形并用这个标题来描述。我们有10个基本布局组件涵盖了大多数情况
- 文本
- 标题
- 配图
- 配图标题
- 表格
- 表格标题
- 页头
- 页尾
- 参考引用
- 公式
请尝试以下命令以查看布局检测结果。
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有显示检测结果的图像,如下所示:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
- **TSRTable Structure Recognition表结构识别**。数据表是一种常用的结构用于表示包括数字或文本在内的数据。表的结构可能非常复杂比如层次结构标题、跨单元格和投影行标题。除了TSR我们还将内容重新组合成LLM可以很好理解的句子。TSR任务有五个标签
- 列
- 行
- 列标题
- 行标题
- 合并单元格
请尝试以下命令以查看布局检测结果。
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` 其中包含图像和html页面这些页面展示了以下检测结果
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
<a name="3"></a>
## 3. 解析器
PDF、DOCX、EXCEL和PPT四种文档格式都有相应的解析器。最复杂的是PDF解析器因为PDF具有灵活性。PDF解析器的输出包括
- 在PDF中有自己位置的文本块页码和矩形位置
- 带有PDF裁剪图像的表格以及已经翻译成自然语言句子的内容。
- 图中带标题和文字的图。
### 简历
简历是一种非常复杂的文档。由各种格式的非结构化文本构成的简历可以被解析为包含近百个字段的结构化数据。我们还没有启用解析器,因为在解析过程之后才会启动处理方法。

View File

@@ -0,0 +1,2 @@
from beartype.claw import beartype_this_package
beartype_this_package()

View File

@@ -0,0 +1,24 @@
from .docx_parser import RAGDocxParser as DocxParser
from .excel_parser import RAGExcelParser as ExcelParser
from .html_parser import RAGHtmlParser as HtmlParser
from .json_parser import RAGJsonParser as JsonParser
from .markdown_parser import MarkdownElementExtractor
from .markdown_parser import RAGMarkdownParser as MarkdownParser
from .pdf_parser import PlainParser
from .pdf_parser import RAGPdfParser as PdfParser
from .ppt_parser import RAGPptParser as PptParser
from .txt_parser import RAGTxtParser as TxtParser
__all__ = [
"PdfParser",
"PlainParser",
"DocxParser",
"ExcelParser",
"PptParser",
"HtmlParser",
"JsonParser",
"MarkdownParser",
"TxtParser",
"MarkdownElementExtractor",
]

View File

@@ -0,0 +1,123 @@
from docx import Document
import re
import pandas as pd
from collections import Counter
from app.core.rag.nlp import rag_tokenizer
from io import BytesIO
class RAGDocxParser:
def __extract_table_content(self, tb):
df = []
for row in tb.rows:
df.append([c.text for c in row.cells])
return self.__compose_table_content(pd.DataFrame(df))
def __compose_table_content(self, df):
def blockType(b):
pattern = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg")
]
for p, n in pattern:
if re.search(p, b):
return n
tks = [t for t in rag_tokenizer.tokenize(b).split() if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
if len(df) < 2:
return []
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
1, len(df)) for j in range(len(df.iloc[i, :]))])
max_type = max(max_type.items(), key=lambda x: x[1])[0]
colnm = len(df.iloc[0, :])
hdrows = [0] # header is not necessarily appear in the first line
if max_type == "Nu":
for r in range(1, len(df)):
tys = Counter([blockType(str(df.iloc[r, j]))
for j in range(len(df.iloc[r, :]))])
tys = max(tys.items(), key=lambda x: x[1])[0]
if tys != max_type:
hdrows.append(r)
lines = []
for i in range(1, len(df)):
if i in hdrows:
continue
hr = [r - i for r in hdrows]
hr = [r for r in hr if r < 0]
t = len(hr) - 1
while t > 0:
if hr[t] - hr[t - 1] > 1:
hr = hr[t:]
break
t -= 1
headers = []
for j in range(len(df.iloc[i, :])):
t = []
for h in hr:
x = str(df.iloc[i + h, j]).strip()
if x in t:
continue
t.append(x)
t = ",".join(t)
if t:
t += ": "
headers.append(t)
cells = []
for j in range(len(df.iloc[i, :])):
if not str(df.iloc[i, j]):
continue
cells.append(headers[j] + str(df.iloc[i, j]))
lines.append(";".join(cells))
if colnm > 3:
return lines
return ["\n".join(lines)]
def __call__(self, fnm, from_page=0, to_page=100000000):
self.doc = Document(fnm) if isinstance(
fnm, str) else Document(BytesIO(fnm))
pn = 0 # parsed page
secs = [] # parsed contents
for p in self.doc.paragraphs:
if pn > to_page:
break
runs_within_single_paragraph = [] # save runs within the range of pages
for run in p.runs:
if pn > to_page:
break
if from_page <= pn < to_page and p.text.strip():
runs_within_single_paragraph.append(run.text) # append run.text first
# wrap page break checker into a static method
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, 'name') else '')) # then concat run.text as part of the paragraph
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
return secs, tbls

View File

@@ -0,0 +1,210 @@
import logging
import re
import sys
from io import BytesIO
import pandas as pd
from openpyxl import Workbook, load_workbook
from app.core.rag.nlp import find_codec
# copied from `/openpyxl/cell/cell.py`
ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]")
class RAGExcelParser:
@staticmethod
def _load_excel_to_workbook(file_like_object):
if isinstance(file_like_object, bytes):
file_like_object = BytesIO(file_like_object)
# Read first 4 bytes to determine file type
file_like_object.seek(0)
file_head = file_like_object.read(4)
file_like_object.seek(0)
if not (file_head.startswith(b"PK\x03\x04") or file_head.startswith(b"\xd0\xcf\x11\xe0")):
logging.info("Not an Excel file, converting CSV to Excel Workbook")
try:
file_like_object.seek(0)
df = pd.read_csv(file_like_object)
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_csv:
raise Exception(f"Failed to parse CSV and convert to Excel Workbook: {e_csv}")
try:
return load_workbook(file_like_object, data_only=True)
except Exception as e:
logging.info(f"openpyxl load error: {e}, try pandas instead")
try:
file_like_object.seek(0)
try:
dfs = pd.read_excel(file_like_object, sheet_name=None)
return RAGExcelParser._dataframe_to_workbook(dfs)
except Exception as ex:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine")
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
@staticmethod
def _clean_dataframe(df: pd.DataFrame):
def clean_string(s):
if isinstance(s, str):
return ILLEGAL_CHARACTERS_RE.sub(" ", s)
return s
return df.apply(lambda col: col.map(clean_string))
@staticmethod
def _dataframe_to_workbook(df):
# if contains multiple sheets use _dataframes_to_workbook
if isinstance(df, dict) and len(df) > 1:
return RAGExcelParser._dataframes_to_workbook(df)
df = RAGExcelParser._clean_dataframe(df)
wb = Workbook()
ws = wb.active
ws.title = "Data"
for col_num, column_name in enumerate(df.columns, 1):
ws.cell(row=1, column=col_num, value=column_name)
for row_num, row in enumerate(df.values, 2):
for col_num, value in enumerate(row, 1):
ws.cell(row=row_num, column=col_num, value=value)
return wb
@staticmethod
def _dataframes_to_workbook(dfs: dict):
wb = Workbook()
default_sheet = wb.active
wb.remove(default_sheet)
for sheet_name, df in dfs.items():
df = RAGExcelParser._clean_dataframe(df)
ws = wb.create_sheet(title=sheet_name)
for col_num, column_name in enumerate(df.columns, 1):
ws.cell(row=1, column=col_num, value=column_name)
for row_num, row in enumerate(df.values, 2):
for col_num, value in enumerate(row, 1):
ws.cell(row=row_num, column=col_num, value=value)
return wb
def html(self, fnm, chunk_rows=256):
from html import escape
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
tb_chunks = []
def _fmt(v):
if v is None:
return ""
return str(v).strip()
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
rows = list(ws.rows)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
if not rows:
continue
tb_rows_0 = "<tr>"
for t in list(rows[0]):
tb_rows_0 += f"<th>{escape(_fmt(t.value))}</th>"
tb_rows_0 += "</tr>"
for chunk_i in range((len(rows) - 1) // chunk_rows + 1):
tb = ""
tb += f"<table><caption>{sheetname}</caption>"
tb += tb_rows_0
for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
tb += "<tr>"
for i, c in enumerate(r):
if c.value is None:
tb += "<td></td>"
else:
tb += f"<td>{escape(_fmt(c.value))}</td>"
tb += "</tr>"
tb += "</table>\n"
tb_chunks.append(tb)
return tb_chunks
def markdown(self, fnm):
import pandas as pd
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
try:
file_like_object.seek(0)
df = pd.read_excel(file_like_object)
except Exception as e:
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
file_like_object.seek(0)
df = pd.read_csv(file_like_object)
df = df.replace(r"^\s*$", "", regex=True)
return df.to_markdown(index=False)
def __call__(self, fnm):
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
rows = list(ws.rows)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
if not rows:
continue
ti = list(rows[0])
for r in list(rows[1:]):
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
return res
@staticmethod
def row_number(fnm, binary):
if fnm.split(".")[-1].lower().find("xls") >= 0:
wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
try:
ws = wb[sheetname]
total += len(list(ws.rows))
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
return total
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
return len(txt.split("\n"))
if __name__ == "__main__":
psr = RAGExcelParser()
psr(sys.argv[1])

View File

@@ -0,0 +1,118 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
from app.core.rag.common.constants import LLMType
from app.core.rag.common.connection_utils import timeout
from app.core.rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
return [
(
(figure_data[1], [figure_data[0]]),
[(0, 0, 0, 0, 0)],
)
for figure_data in figures_data_without_positions
if isinstance(figure_data[1], Image.Image)
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,vision_model=None,**kwargs):
if vision_model:
figures_data = vision_figure_parser_figure_data_wrapper(sections)
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tbls.extend(boosted_figures)
except Exception as e:
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
def vision_figure_parser_pdf_wrapper(tbls,callback=None,vision_model=None,**kwargs):
if vision_model:
def is_figure_item(item):
return (
isinstance(item[0][0], Image.Image) and
isinstance(item[0][1], list)
)
figures_data = [item for item in tbls if is_figure_item(item)]
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tbls = [item for item in tbls if not is_figure_item(item)]
tbls.extend(boosted_figures)
except Exception as e:
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
shared_executor = ThreadPoolExecutor(max_workers=10)
class VisionFigureParser:
def __init__(self, vision_model, figures_data, *args, **kwargs):
self.vision_model = vision_model
self._extract_figures_info(figures_data)
assert len(self.figures) == len(self.descriptions)
assert not self.positions or (len(self.figures) == len(self.positions))
def _extract_figures_info(self, figures_data):
self.figures = []
self.descriptions = []
self.positions = []
for item in figures_data:
# position
if len(item) == 2 and isinstance(item[0], tuple) and len(item[0]) == 2 and isinstance(item[1], list) and isinstance(item[1][0], tuple) and len(item[1][0]) == 5:
img_desc = item[0]
assert len(img_desc) == 2 and isinstance(img_desc[0], Image.Image) and isinstance(img_desc[1], list), "Should be (figure, [description])"
self.figures.append(img_desc[0])
self.descriptions.append(img_desc[1])
self.positions.append(item[1])
else:
assert len(item) == 2 and isinstance(item[0], Image.Image) and isinstance(item[1], list), f"Unexpected form of figure data: get {len(item)=}, {item=}"
self.figures.append(item[0])
self.descriptions.append(item[1])
def _assemble(self):
self.assembled = []
self.has_positions = len(self.positions) != 0
for i in range(len(self.figures)):
figure = self.figures[i]
desc = self.descriptions[i]
pos = self.positions[i] if self.has_positions else None
figure_desc = (figure, desc)
if pos is not None:
self.assembled.append((figure_desc, pos))
else:
self.assembled.append((figure_desc,))
return self.assembled
def __call__(self, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None)
@timeout(30, 3)
def process(figure_idx, figure_binary):
description_text = picture_vision_llm_chunk(
binary=figure_binary,
vision_model=self.vision_model,
prompt=vision_llm_figure_describe_prompt(),
callback=callback,
)
return figure_idx, description_text
futures = []
for idx, img_binary in enumerate(self.figures or []):
futures.append(shared_executor.submit(process, idx, img_binary))
for future in as_completed(futures):
figure_num, txt = future.result()
if txt:
self.descriptions[figure_num] = txt + "\n".join(self.descriptions[figure_num])
self._assemble()
return self.assembled

View File

@@ -0,0 +1,197 @@
from app.core.rag.nlp import find_codec, rag_tokenizer
import uuid
import chardet
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
import html
def get_encoding(file):
with open(file,'rb') as f:
tmp = chardet.detect(f.read())
return tmp['encoding']
BLOCK_TAGS = [
"h1", "h2", "h3", "h4", "h5", "h6",
"p", "div", "article", "section", "aside",
"ul", "ol", "li",
"table", "pre", "code", "blockquote",
"figure", "figcaption"
]
TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "#####", "h5": "#####", "h6": "######"}
class RAGHtmlParser:
def __call__(self, fnm, binary=None, chunk_token_num=512):
if binary:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
else:
with open(fnm, "r",encoding=get_encoding(fnm)) as f:
txt = f.read()
return self.parser_txt(txt, chunk_token_num)
@classmethod
def parser_txt(cls, txt, chunk_token_num):
if not isinstance(txt, str):
raise TypeError("txt type should be string!")
temp_sections = []
soup = BeautifulSoup(txt, "html5lib")
# delete <style> tag
for style_tag in soup.find_all(["style", "script"]):
style_tag.decompose()
# delete <script> tag in <div>
for div_tag in soup.find_all("div"):
for script_tag in div_tag.find_all("script"):
script_tag.decompose()
# delete inline style
for tag in soup.find_all(True):
if 'style' in tag.attrs:
del tag.attrs['style']
# delete HTML comment
for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
comment.extract()
cls.read_text_recursively(soup.body, temp_sections, chunk_token_num=chunk_token_num)
block_txt_list, table_list = cls.merge_block_text(temp_sections)
sections = cls.chunk_block(block_txt_list, chunk_token_num=chunk_token_num)
for table in table_list:
sections.append(table.get("content", ""))
return sections
@classmethod
def split_table(cls, html_table, chunk_token_num=512):
soup = BeautifulSoup(html_table, "html.parser")
rows = soup.find_all("tr")
tables = []
current_table = []
current_count = 0
table_str_list = []
for row in rows:
tks_str = rag_tokenizer.tokenize(str(row))
token_count = len(tks_str.split(" ")) if tks_str else 0
if current_count + token_count > chunk_token_num:
tables.append(current_table)
current_table = []
current_count = 0
current_table.append(row)
current_count += token_count
if current_table:
tables.append(current_table)
for table_rows in tables:
new_table = soup.new_tag("table")
for row in table_rows:
new_table.append(row)
table_str_list.append(str(new_table))
return table_str_list
@classmethod
def read_text_recursively(cls, element, parser_result, chunk_token_num=512, parent_name=None, block_id=None):
if isinstance(element, NavigableString):
content = element.strip()
def is_valid_html(content):
try:
soup = BeautifulSoup(content, "html.parser")
return bool(soup.find())
except Exception:
return False
return_info = []
if content:
if is_valid_html(content):
soup = BeautifulSoup(content, "html.parser")
child_info = cls.read_text_recursively(soup, parser_result, chunk_token_num, element.name, block_id)
parser_result.extend(child_info)
else:
info = {"content": element.strip(), "tag_name": "inner_text", "metadata": {"block_id": block_id}}
if parent_name:
info["tag_name"] = parent_name
return_info.append(info)
return return_info
elif isinstance(element, Tag):
if str.lower(element.name) == "table":
table_info_list = []
table_id = str(uuid.uuid1())
table_list = [html.unescape(str(element))]
for t in table_list:
table_info_list.append({"content": t, "tag_name": "table",
"metadata": {"table_id": table_id, "index": table_list.index(t)}})
return table_info_list
else:
block_id = None
if str.lower(element.name) in BLOCK_TAGS:
block_id = str(uuid.uuid1())
for child in element.children:
child_info = cls.read_text_recursively(child, parser_result, chunk_token_num, element.name,
block_id)
parser_result.extend(child_info)
return []
@classmethod
def merge_block_text(cls, parser_result):
block_content = []
current_content = ""
table_info_list = []
lask_block_id = None
for item in parser_result:
content = item.get("content")
tag_name = item.get("tag_name")
title_flag = tag_name in TITLE_TAGS
block_id = item.get("metadata", {}).get("block_id")
if block_id:
if title_flag:
content = f"{TITLE_TAGS[tag_name]} {content}"
if lask_block_id != block_id:
if lask_block_id is not None:
block_content.append(current_content)
current_content = content
lask_block_id = block_id
else:
current_content += (" " if current_content else "") + content
else:
if tag_name == "table":
table_info_list.append(item)
else:
current_content += (" " if current_content else "" + content)
if current_content:
block_content.append(current_content)
return block_content, table_info_list
@classmethod
def chunk_block(cls, block_txt_list, chunk_token_num=512):
chunks = []
current_block = ""
current_token_count = 0
for block in block_txt_list:
tks_str = rag_tokenizer.tokenize(block)
block_token_count = len(tks_str.split(" ")) if tks_str else 0
if block_token_count > chunk_token_num:
if current_block:
chunks.append(current_block)
start = 0
tokens = tks_str.split(" ")
while start < len(tokens):
end = start + chunk_token_num
split_tokens = tokens[start:end]
chunks.append(" ".join(split_tokens))
start = end
current_block = ""
current_token_count = 0
else:
if current_token_count + block_token_count <= chunk_token_num:
current_block += ("\n" if current_block else "") + block
current_token_count += block_token_count
else:
chunks.append(current_block)
current_block = block
current_token_count = block_token_count
if current_block:
chunks.append(current_block)
return chunks

View File

@@ -0,0 +1,159 @@
import json
from typing import Any
from app.core.rag.nlp import find_codec
class RAGJsonParser:
def __init__(self, max_chunk_size: int = 2000, min_chunk_size: int | None = None):
super().__init__()
self.max_chunk_size = max_chunk_size * 2
self.min_chunk_size = min_chunk_size if min_chunk_size is not None else max(max_chunk_size - 200, 50)
def __call__(self, filename):
with open(filename, "r") as f:
txt = f.read()
if self.is_jsonl_format(txt):
sections = self._parse_jsonl(txt)
else:
sections = self._parse_json(txt)
return sections
@staticmethod
def _json_size(data: dict) -> int:
"""Calculate the size of the serialized JSON object."""
return len(json.dumps(data, ensure_ascii=False))
@staticmethod
def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
"""Set a value in a nested dictionary based on the given path."""
for key in path[:-1]:
d = d.setdefault(key, {})
d[path[-1]] = value
def _list_to_dict_preprocessing(self, data: Any) -> Any:
if isinstance(data, dict):
# Process each key-value pair in the dictionary
return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
elif isinstance(data, list):
# Convert the list to a dictionary with index-based keys
return {str(i): self._list_to_dict_preprocessing(item) for i, item in enumerate(data)}
else:
# Base case: the item is neither a dict nor a list, so return it unchanged
return data
def _json_split(
self,
data,
current_path: list[str] | None,
chunks: list[dict] | None,
) -> list[dict]:
"""
Split json into maximum size dictionaries while preserving structure.
"""
current_path = current_path or []
chunks = chunks or [{}]
if isinstance(data, dict):
for key, value in data.items():
new_path = current_path + [key]
chunk_size = self._json_size(chunks[-1])
size = self._json_size({key: value})
remaining = self.max_chunk_size - chunk_size
if size < remaining:
# Add item to current chunk
self._set_nested_dict(chunks[-1], new_path, value)
else:
if chunk_size >= self.min_chunk_size:
# Chunk is big enough, start a new chunk
chunks.append({})
# Iterate
self._json_split(value, new_path, chunks)
else:
# handle single item
self._set_nested_dict(chunks[-1], current_path, data)
return chunks
def split_json(
self,
json_data,
convert_lists: bool = False,
) -> list[dict]:
"""Splits JSON into a list of JSON chunks"""
if convert_lists:
preprocessed_data = self._list_to_dict_preprocessing(json_data)
chunks = self._json_split(preprocessed_data, None, None)
else:
chunks = self._json_split(json_data, None, None)
# Remove the last chunk if it's empty
if not chunks[-1]:
chunks.pop()
return chunks
def split_text(
self,
json_data: dict[str, Any],
convert_lists: bool = False,
ensure_ascii: bool = True,
) -> list[str]:
"""Splits JSON into a list of JSON formatted strings"""
chunks = self.split_json(json_data=json_data, convert_lists=convert_lists)
# Convert to string
return [json.dumps(chunk, ensure_ascii=ensure_ascii) for chunk in chunks]
def _parse_json(self, content: str) -> list[str]:
sections = []
try:
json_data = json.loads(content)
chunks = self.split_json(json_data, True)
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
except json.JSONDecodeError:
pass
return sections
def _parse_jsonl(self, content: str) -> list[str]:
lines = content.strip().splitlines()
all_chunks = []
for line in lines:
if not line.strip():
continue
try:
data = json.loads(line)
chunks = self.split_json(data, convert_lists=True)
all_chunks.extend(json.dumps(chunk, ensure_ascii=False) for chunk in chunks if chunk)
except json.JSONDecodeError:
continue
return all_chunks
def is_jsonl_format(self, txt: str, sample_limit: int = 10, threshold: float = 0.8) -> bool:
lines = [line.strip() for line in txt.strip().splitlines() if line.strip()]
if not lines:
return False
try:
json.loads(txt)
return False
except json.JSONDecodeError:
pass
sample_limit = min(len(lines), sample_limit)
sample_lines = lines[:sample_limit]
valid_lines = sum(1 for line in sample_lines if self._is_valid_json(line))
if not valid_lines:
return False
return (valid_lines / len(sample_lines)) >= threshold
def _is_valid_json(self, line: str) -> bool:
try:
json.loads(line)
return True
except json.JSONDecodeError:
return False

View File

@@ -0,0 +1,277 @@
import re
from markdown import markdown
class RAGMarkdownParser:
def __init__(self, chunk_token_num=128):
self.chunk_token_num = int(chunk_token_num)
def extract_tables_and_remainder(self, markdown_text, separate_tables=True):
tables = []
working_text = markdown_text
def replace_tables_with_rendered_html(pattern, table_list, render=True):
new_text = ""
last_end = 0
for match in pattern.finditer(working_text):
raw_table = match.group()
table_list.append(raw_table)
if separate_tables:
# Skip this match (i.e., remove it)
new_text += working_text[last_end : match.start()] + "\n\n"
else:
# Replace with rendered HTML
html_table = markdown(raw_table, extensions=["markdown.extensions.tables"]) if render else raw_table
new_text += working_text[last_end : match.start()] + html_table + "\n\n"
last_end = match.end()
new_text += working_text[last_end:]
return new_text
if "|" in markdown_text: # for optimize performance
# Standard Markdown table
border_table_pattern = re.compile(
r"""
(?:\n|^)
(?:\|.*?\|.*?\|.*?\n)
(?:\|(?:\s*[:-]+[-| :]*\s*)\|.*?\n)
(?:\|.*?\|.*?\|.*?\n)+
""",
re.VERBOSE,
)
working_text = replace_tables_with_rendered_html(border_table_pattern, tables)
# Borderless Markdown table
no_border_table_pattern = re.compile(
r"""
(?:\n|^)
(?:\S.*?\|.*?\n)
(?:(?:\s*[:-]+[-| :]*\s*).*?\n)
(?:\S.*?\|.*?\n)+
""",
re.VERBOSE,
)
working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)
# Replace any TAGS e.g. <table ...> to <table>
TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
table_with_attributes_pattern = re.compile(
rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
)
def replace_tag(m):
tag_name = re.match(r"<(\w+)", m.group()).group(1)
return "<{}>".format(tag_name)
working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)
if "<table>" in working_text.lower(): # for optimize performance
# HTML table extraction - handle possible html/body wrapper tags
html_table_pattern = re.compile(
r"""
(?:\n|^)
\s*
(?:
# case1: <html><body><table>...</table></body></html>
(?:<html[^>]*>\s*<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>\s*</html>)
|
# case2: <body><table>...</table></body>
(?:<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>)
|
# case3: only<table>...</table>
(?:<table[^>]*>.*?</table>)
)
\s*
(?=\n|$)
""",
re.VERBOSE | re.DOTALL | re.IGNORECASE,
)
def replace_html_tables():
nonlocal working_text
new_text = ""
last_end = 0
for match in html_table_pattern.finditer(working_text):
raw_table = match.group()
tables.append(raw_table)
if separate_tables:
new_text += working_text[last_end : match.start()] + "\n\n"
else:
new_text += working_text[last_end : match.start()] + raw_table + "\n\n"
last_end = match.end()
new_text += working_text[last_end:]
working_text = new_text
replace_html_tables()
return working_text, tables
class MarkdownElementExtractor:
def __init__(self, markdown_content):
self.markdown_content = markdown_content
self.lines = markdown_content.split("\n")
def get_delimiters(self,delimiters):
toks = re.findall(r"`([^`]+)`", delimiters)
toks = sorted(set(toks), key=lambda x: -len(x))
return "|".join(re.escape(t) for t in toks if t)
def extract_elements(self,delimiter=None):
"""Extract individual elements (headers, code blocks, lists, etc.)"""
sections = []
i = 0
dels=""
if delimiter:
dels = self.get_delimiters(delimiter)
if len(dels) > 0:
text = "\n".join(self.lines)
parts = re.split(dels, text)
sections = [p.strip() for p in parts if p and p.strip()]
return sections
while i < len(self.lines):
line = self.lines[i]
if re.match(r"^#{1,6}\s+.*$", line):
# header
element = self._extract_header(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip().startswith("```"):
# code block
element = self._extract_code_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line):
# list block
element = self._extract_list_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip().startswith(">"):
# blockquote
element = self._extract_blockquote(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip():
# text block (paragraphs and inline elements until next block element)
element = self._extract_text_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
else:
i += 1
sections = [section for section in sections if section.strip()]
return sections
def _extract_header(self, start_pos):
return {
"type": "header",
"content": self.lines[start_pos],
"start_line": start_pos,
"end_line": start_pos,
}
def _extract_code_block(self, start_pos):
end_pos = start_pos
content_lines = [self.lines[start_pos]]
# Find the end of the code block
for i in range(start_pos + 1, len(self.lines)):
content_lines.append(self.lines[i])
end_pos = i
if self.lines[i].strip().startswith("```"):
break
return {
"type": "code_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_list_block(self, start_pos):
end_pos = start_pos
content_lines = []
i = start_pos
while i < len(self.lines):
line = self.lines[i]
# check if this line is a list item or continuation of a list
if (
re.match(r"^\s*[-*+]\s+.*$", line)
or re.match(r"^\s*\d+\.\s+.*$", line)
or (i > start_pos and not line.strip())
or (i > start_pos and re.match(r"^\s{2,}[-*+]\s+.*$", line))
or (i > start_pos and re.match(r"^\s{2,}\d+\.\s+.*$", line))
or (i > start_pos and re.match(r"^\s+\w+.*$", line))
):
content_lines.append(line)
end_pos = i
i += 1
else:
break
return {
"type": "list_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_blockquote(self, start_pos):
end_pos = start_pos
content_lines = []
i = start_pos
while i < len(self.lines):
line = self.lines[i]
if line.strip().startswith(">") or (i > start_pos and not line.strip()):
content_lines.append(line)
end_pos = i
i += 1
else:
break
return {
"type": "blockquote",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_text_block(self, start_pos):
"""Extract a text block (paragraphs, inline elements) until next block element"""
end_pos = start_pos
content_lines = [self.lines[start_pos]]
i = start_pos + 1
while i < len(self.lines):
line = self.lines[i]
# stop if we encounter a block element
if re.match(r"^#{1,6}\s+.*$", line) or line.strip().startswith("```") or re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line) or line.strip().startswith(">"):
break
elif not line.strip():
# check if the next line is a block element
if i + 1 < len(self.lines) and (
re.match(r"^#{1,6}\s+.*$", self.lines[i + 1])
or self.lines[i + 1].strip().startswith("```")
or re.match(r"^\s*[-*+]\s+.*$", self.lines[i + 1])
or re.match(r"^\s*\d+\.\s+.*$", self.lines[i + 1])
or self.lines[i + 1].strip().startswith(">")
):
break
else:
content_lines.append(line)
end_pos = i
i += 1
else:
content_lines.append(line)
end_pos = i
i += 1
return {
"type": "text_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}

View File

@@ -0,0 +1,524 @@
import json
import logging
import os
import platform
import re
import subprocess
import sys
import tempfile
import threading
import time
import zipfile
from io import BytesIO
from os import PathLike
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Optional
import numpy as np
import pdfplumber
import requests
from PIL import Image
from strenum import StrEnum
from .pdf_parser import RAGPdfParser
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
class MinerUContentType(StrEnum):
IMAGE = "image"
TABLE = "table"
TEXT = "text"
EQUATION = "equation"
CODE = "code"
LIST = "list"
DISCARDED = "discarded"
class MinerUParser(RAGPdfParser):
def __init__(self, mineru_path: str = "mineru", mineru_api: str = "http://host.docker.internal:9987", mineru_server_url: str = ""):
self.mineru_path = Path(mineru_path)
self.mineru_api = mineru_api.rstrip("/")
self.mineru_server_url = mineru_server_url.rstrip("/")
self.using_api = False
self.logger = logging.getLogger(self.__class__.__name__)
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
with zipfile.ZipFile(zip_path, "r") as zip_ref:
if not root_dir:
files = zip_ref.namelist()
if files and files[0].endswith("/"):
root_dir = files[0]
else:
root_dir = None
if not root_dir or not root_dir.endswith("/"):
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
zip_ref.extractall(extract_to)
return
root_len = len(root_dir)
for member in zip_ref.infolist():
filename = member.filename
if filename == root_dir:
self.logger.info("[MinerU] Ignore root folder...")
continue
path = filename
if path.startswith(root_dir):
path = path[root_len:]
full_path = os.path.join(extract_to, path)
if member.is_dir():
os.makedirs(full_path, exist_ok=True)
else:
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, "wb") as f:
f.write(zip_ref.read(filename))
def _is_http_endpoint_valid(self, url, timeout=5):
try:
response = requests.head(url, timeout=timeout, allow_redirects=True)
return response.status_code in [200, 301, 302, 307, 308]
except Exception:
return False
def check_installation(self, backend: str = "pipeline", server_url: Optional[str] = None) -> tuple[bool, str]:
reason = ""
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
if backend not in valid_backends:
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
logging.warning(reason)
return False, reason
subprocess_kwargs = {
"capture_output": True,
"text": True,
"check": True,
"encoding": "utf-8",
"errors": "ignore",
}
if platform.system() == "Windows":
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
if server_url is None:
server_url = self.mineru_server_url
if backend == "vlm-http-client" and server_url:
try:
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
if server_accessible:
self.using_api = False # We are using http client, not API
return True, reason
else:
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
return False, reason
except Exception as e:
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
try:
response = requests.get(server_url, timeout=5)
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
self.using_api = False
return True, reason
except Exception as e:
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
return False, reason
try:
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
version_info = result.stdout.strip()
if version_info:
logging.info(f"[MinerU] Detected version: {version_info}")
else:
logging.info("[MinerU] Detected MinerU, but version info is empty.")
return True, reason
except subprocess.CalledProcessError as e:
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
except FileNotFoundError:
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
except Exception as e:
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
# If executable check fails, try API check
try:
if self.mineru_api:
# check openapi.json
openapi_exists = self._is_http_endpoint_valid(self.mineru_api + "/openapi.json")
if not openapi_exists:
reason = "[MinerU] Failed to detect vaild MinerU API server"
return openapi_exists, reason
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
self.using_api = openapi_exists
return openapi_exists, reason
else:
logging.info("[MinerU] api not exists.")
except Exception as e:
reason = f"[MinerU] Unexpected error during api check: {e}"
logging.error(f"[MinerU] Unexpected error during api check: {e}")
return False, reason
def _run_mineru(
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
):
if self.using_api:
self._run_mineru_api(input_path, output_dir, method, backend, lang, callback)
else:
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
pdf_file_path = str(input_path)
if not os.path.exists(pdf_file_path):
raise RuntimeError(f"[MinerU] PDF file not exists: {pdf_file_path}")
pdf_file_name = Path(pdf_file_path).stem.strip()
output_path = os.path.join(str(output_dir), pdf_file_name, method)
os.makedirs(output_path, exist_ok=True)
files = {"files": (pdf_file_name + ".pdf", open(pdf_file_path, "rb"), "application/pdf")}
data = {
"output_dir": "./output",
"lang_list": lang,
"backend": backend,
"parse_method": method,
"formula_enable": True,
"table_enable": True,
"server_url": None,
"return_md": True,
"return_middle_json": True,
"return_model_output": True,
"return_content_list": True,
"return_images": True,
"response_format_zip": True,
"start_page_id": 0,
"end_page_id": 99999,
}
headers = {"Accept": "application/json"}
try:
self.logger.info(f"[MinerU] invoke api: {self.mineru_api}/file_parse")
if callback:
callback(0.20, f"[MinerU] invoke api: {self.mineru_api}/file_parse")
response = requests.post(url=f"{self.mineru_api}/file_parse", files=files, data=data, headers=headers, timeout=1800)
response.raise_for_status()
if response.headers.get("Content-Type") == "application/zip":
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
if callback:
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
with open(OUTPUT_ZIP_PATH, "wb") as f:
f.write(response.content)
self.logger.info(f"[MinerU] Unzip to {output_path}...")
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
if callback:
callback(0.40, f"[MinerU] Unzip to {output_path}...")
else:
self.logger.warning("[MinerU] not zip returned from api%s " % response.headers.get("Content-Type"))
except Exception as e:
raise RuntimeError(f"[MinerU] api failed with exception {e}")
self.logger.info("[MinerU] Api completed successfully.")
def _run_mineru_executable(
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
):
cmd = [str(self.mineru_path), "-p", str(input_path), "-o", str(output_dir), "-m", method]
if backend:
cmd.extend(["-b", backend])
if lang:
cmd.extend(["-l", lang])
if server_url and backend == "vlm-http-client":
cmd.extend(["-u", server_url])
self.logger.info(f"[MinerU] Running command: {' '.join(cmd)}")
subprocess_kwargs = {
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
"text": True,
"encoding": "utf-8",
"errors": "ignore",
"bufsize": 1,
}
if platform.system() == "Windows":
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
process = subprocess.Popen(cmd, **subprocess_kwargs)
stdout_queue, stderr_queue = Queue(), Queue()
def enqueue_output(pipe, queue, prefix):
for line in iter(pipe.readline, ""):
if line.strip():
queue.put((prefix, line.strip()))
pipe.close()
threading.Thread(target=enqueue_output, args=(process.stdout, stdout_queue, "STDOUT"), daemon=True).start()
threading.Thread(target=enqueue_output, args=(process.stderr, stderr_queue, "STDERR"), daemon=True).start()
while process.poll() is None:
for q in (stdout_queue, stderr_queue):
try:
while True:
prefix, line = q.get_nowait()
if prefix == "STDOUT":
self.logger.info(f"[MinerU] {line}")
else:
self.logger.warning(f"[MinerU] {line}")
except Empty:
pass
time.sleep(0.1)
return_code = process.wait()
if return_code != 0:
raise RuntimeError(f"[MinerU] Process failed with exit code {return_code}")
self.logger.info("[MinerU] Command completed successfully.")
def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None):
self.page_from = page_from
self.page_to = page_to
try:
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
self.pdf = pdf
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
except Exception as e:
self.page_images = None
self.total_page = 0
logging.exception(e)
def _line_tag(self, bx):
pn = [bx["page_idx"] + 1]
positions = bx["bbox"]
x0, top, x1, bott = positions
if hasattr(self, "page_images") and self.page_images and len(self.page_images) > bx["page_idx"]:
page_width, page_height = self.page_images[bx["page_idx"]].size
x0 = (x0 / 1000.0) * page_width
x1 = (x1 / 1000.0) * page_width
top = (top / 1000.0) * page_height
bott = (bott / 1000.0) * page_height
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format("-".join([str(p) for p in pn]), x0, x1, top, bott)
def crop(self, text, ZM=1, need_position=False):
imgs = []
poss = self.extract_positions(text)
if not poss:
if need_position:
return None, None
return
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
GAP = 6
pos = poss[0]
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
pos = poss[-1]
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120)))
positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss):
right = left + max_width
if bottom <= top:
bottom = top + 2
for pn in pns[1:]:
bottom += self.page_images[pn - 1].size[1]
img0 = self.page_images[pns[0]]
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
crop0 = img0.crop((x0, y0, x1, y1))
imgs.append(crop0)
if 0 < ii < len(poss) - 1:
positions.append((pns[0] + self.page_from, x0, x1, y0, y1))
bottom -= img0.size[1]
for pn in pns[1:]:
page = self.page_images[pn]
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
cimgp = page.crop((x0, y0, x1, y1))
imgs.append(cimgp)
if 0 < ii < len(poss) - 1:
positions.append((pn + self.page_from, x0, x1, y0, y1))
bottom -= page.size[1]
if not imgs:
if need_position:
return None, None
return
height = 0
for img in imgs:
height += img.size[1] + GAP
height = int(height)
width = int(np.max([i.size[0] for i in imgs]))
pic = Image.new("RGB", (width, height), (245, 245, 245))
height = 0
for ii, img in enumerate(imgs):
if ii == 0 or ii + 1 == len(imgs):
img = img.convert("RGBA")
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
overlay.putalpha(128)
img = Image.alpha_composite(img, overlay).convert("RGB")
pic.paste(img, (0, int(height)))
height += img.size[1] + GAP
if need_position:
return pic, positions
return pic
@staticmethod
def extract_positions(txt: str):
poss = []
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
return poss
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
subdir = output_dir / file_stem / method
if backend.startswith("vlm-"):
subdir = output_dir / file_stem / "vlm"
json_file = subdir / f"{file_stem}_content_list.json"
if not json_file.exists():
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
for item in data:
for key in ("img_path", "table_img_path", "equation_img_path"):
if key in item and item[key]:
item[key] = str((subdir / item[key]).resolve())
return data
def _transfer_to_sections(self, outputs: list[dict[str, Any]]):
sections = []
for output in outputs:
match output["type"]:
case MinerUContentType.TEXT:
section = output["text"]
case MinerUContentType.TABLE:
section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(output.get("table_footnote", []))
if not section.strip():
section = "FAILED TO PARSE TABLE"
case MinerUContentType.IMAGE:
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
case MinerUContentType.EQUATION:
section = output["text"]
case MinerUContentType.CODE:
section = output["code_body"] + "\n".join(output.get("code_caption", []))
case MinerUContentType.LIST:
section = "\n".join(output.get("list_items", []))
case MinerUContentType.DISCARDED:
pass
if section:
sections.append((section, self._line_tag(output)))
return sections
def _transfer_to_tables(self, outputs: list[dict[str, Any]]):
return []
def parse_pdf(
self,
filepath: str | PathLike[str],
binary: BytesIO | bytes,
callback: Optional[Callable] = None,
*,
output_dir: Optional[str] = None,
backend: str = "pipeline",
lang: Optional[str] = None,
method: str = "auto",
server_url: Optional[str] = None,
delete_output: bool = True,
) -> tuple:
import shutil
temp_pdf = None
created_tmp_dir = False
# remove spaces, or mineru crash, and _read_output fail too
file_path = Path(filepath)
pdf_file_name = file_path.stem.replace(" ", "") + ".pdf"
pdf_file_path_valid = os.path.join(file_path.parent, pdf_file_name)
if binary:
temp_dir = Path(tempfile.mkdtemp(prefix="mineru_bin_pdf_"))
temp_pdf = temp_dir / pdf_file_name
with open(temp_pdf, "wb") as f:
f.write(binary)
pdf = temp_pdf
self.logger.info(f"[MinerU] Received binary PDF -> {temp_pdf}")
if callback:
callback(0.15, f"[MinerU] Received binary PDF -> {temp_pdf}")
else:
if pdf_file_path_valid != filepath:
self.logger.info(f"[MinerU] Remove all space in file name: {pdf_file_path_valid}")
shutil.move(filepath, pdf_file_path_valid)
pdf = Path(pdf_file_path_valid)
if not pdf.exists():
if callback:
callback(-1, f"[MinerU] PDF not found: {pdf}")
raise FileNotFoundError(f"[MinerU] PDF not found: {pdf}")
if output_dir:
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
else:
out_dir = Path(tempfile.mkdtemp(prefix="mineru_pdf_"))
created_tmp_dir = True
self.logger.info(f"[MinerU] Output directory: {out_dir}")
if callback:
callback(0.15, f"[MinerU] Output directory: {out_dir}")
self.__images__(pdf, zoomin=1)
try:
self._run_mineru(pdf, out_dir, method=method, backend=backend, lang=lang, server_url=server_url, callback=callback)
outputs = self._read_output(out_dir, pdf.stem, method=method, backend=backend)
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
return self._transfer_to_sections(outputs), self._transfer_to_tables(outputs)
finally:
if temp_pdf and temp_pdf.exists():
try:
temp_pdf.unlink()
temp_pdf.parent.rmdir()
except Exception:
pass
if delete_output and created_tmp_dir and out_dir.exists():
try:
shutil.rmtree(out_dir)
except Exception:
pass
if __name__ == "__main__":
parser = MinerUParser("mineru")
ok, reason = parser.check_installation()
print("MinerU available:", ok)
filepath = ""
with open(filepath, "rb") as file:
outputs = parser.parse_pdf(filepath=filepath, binary=file.read())
for output in outputs:
print(output)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,83 @@
import logging
from io import BytesIO
from pptx import Presentation
class RAGPptParser:
def __init__(self):
super().__init__()
def __get_bulleted_text(self, paragraph):
is_bulleted = bool(paragraph._p.xpath("./a:pPr/a:buChar")) or bool(paragraph._p.xpath("./a:pPr/a:buAutoNum")) or bool(paragraph._p.xpath("./a:pPr/a:buBlip"))
if is_bulleted:
return f"{' '* paragraph.level}.{paragraph.text}"
else:
return paragraph.text
def __extract(self, shape):
try:
# First try to get text content
if hasattr(shape, 'has_text_frame') and shape.has_text_frame:
text_frame = shape.text_frame
texts = []
for paragraph in text_frame.paragraphs:
if paragraph.text.strip():
texts.append(self.__get_bulleted_text(paragraph))
return "\n".join(texts)
# Safely get shape_type
try:
shape_type = shape.shape_type
except NotImplementedError:
# If shape_type is not available, try to get text content
if hasattr(shape, 'text'):
return shape.text.strip()
return ""
# Handle table
if shape_type == 19:
tb = shape.table
rows = []
for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows)
# Handle group shape
if shape_type == 6:
texts = []
for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
t = self.__extract(p)
if t:
texts.append(t)
return "\n".join(texts)
return ""
except Exception as e:
logging.error(f"Error processing shape: {str(e)}")
return ""
def __call__(self, fnm, from_page, to_page, callback=None):
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
txts = []
self.total_page = len(ppt.slides)
for i, slide in enumerate(ppt.slides):
if i < from_page:
continue
if i >= to_page:
break
texts = []
for shape in sorted(
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
try:
txt = self.__extract(shape)
if txt:
texts.append(txt)
except Exception as e:
logging.exception(e)
txts.append("\n".join(texts))
return txts

View File

@@ -0,0 +1,48 @@
import re
from .utils import get_text
from app.core.rag.common.token_utils import num_tokens_from_string
class RAGTxtParser:
def __call__(self, fnm, binary=None, chunk_token_num=128, delimiter="\n!?;。;!?"):
txt = get_text(fnm, binary)
return self.parser_txt(txt, chunk_token_num, delimiter)
@classmethod
def parser_txt(cls, txt, chunk_token_num=128, delimiter="\n!?;。;!?"):
if not isinstance(txt, str):
raise TypeError("txt type should be str!")
cks = [""]
tk_nums = [0]
delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
def add_chunk(t):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if tk_nums[-1] > chunk_token_num:
cks.append(t)
tk_nums.append(tnum)
else:
cks[-1] += t
tk_nums[-1] += tnum
dels = []
s = 0
for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
f, t = m.span()
dels.append(m.group(1))
dels.extend(list(delimiter[s: f]))
s = t
if s < len(delimiter):
dels.extend(list(delimiter[s:]))
dels = [re.escape(d) for d in dels if d]
dels = [d for d in dels if d]
dels = "|".join(dels)
secs = re.split(r"(%s)" % dels, txt)
for sec in secs:
if re.match(f"^{dels}$", sec):
continue
add_chunk(sec)
return [[c, ""] for c in cks]

View File

@@ -0,0 +1,16 @@
from app.core.rag.nlp import find_codec
def get_text(fnm: str, binary=None) -> str:
txt = ""
if binary:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
else:
with open(fnm, "r") as f:
while True:
line = f.readline()
if not line:
break
txt += line
return txt

View File

@@ -0,0 +1,75 @@
import io
import sys
import threading
import pdfplumber
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import AscendLayoutRecognizer
from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
def init_in_out(args):
import os
import traceback
from PIL import Image
from app.core.rag.common.file_utils import traversal_files
images = []
outputs = []
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
def pdf_pages(fnm, zoomin=3):
nonlocal outputs, images
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(pdf.pages)]
for i, page in enumerate(images):
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
pdf.close()
def images_and_outputs(fnm):
nonlocal outputs, images
if fnm.split(".")[-1].lower() == "pdf":
pdf_pages(fnm)
return
try:
fp = open(fnm, "rb")
binary = fp.read()
fp.close()
images.append(Image.open(io.BytesIO(binary)).convert("RGB"))
outputs.append(os.path.split(fnm)[-1])
except Exception:
traceback.print_exc()
if os.path.isdir(args.inputs):
for fnm in traversal_files(args.inputs):
images_and_outputs(fnm)
else:
images_and_outputs(args.inputs)
for i in range(len(outputs)):
outputs[i] = os.path.join(args.output_dir, outputs[i])
return images, outputs
__all__ = [
"OCR",
"Recognizer",
"LayoutRecognizer",
"AscendLayoutRecognizer",
"TableStructureRecognizer",
"init_in_out",
]

View File

@@ -0,0 +1,440 @@
import logging
import math
import os
import re
from collections import Counter
from copy import deepcopy
import cv2
import numpy as np
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from . import Recognizer
from .operators import nms
class LayoutRecognizer(Recognizer):
labels = [
"_background_",
"Text",
"Title",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Header",
"Footer",
"Reference",
"Equation",
]
def __init__(self, domain):
try:
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"), local_dir_use_symlinks=False)
super().__init__(self.labels, domain, model_dir)
self.garbage_layouts = ["footer", "header", "reference"]
self.client = None
if os.environ.get("TENSORRT_DLA_SVR"):
from deepdoc.vision.dla_cli import DLAClient
self.client = DLAClient(os.environ["TENSORRT_DLA_SVR"])
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
def __is_garbage(b):
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
return any([re.search(p, b["text"]) for p in patt])
if self.client:
layouts = self.client.predict(image_list)
else:
layouts = super().__call__(image_list, thr, batch_size)
# save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
assert len(image_list) == len(ocr_res)
# Tag layout type
boxes = []
assert len(image_list) == len(layouts)
garbages = {}
page_layout = []
for pn, lts in enumerate(layouts):
bxs = ocr_res[pn]
lts = [
{
"type": b["type"],
"score": float(b["score"]),
"x0": b["bbox"][0] / scale_factor,
"x1": b["bbox"][2] / scale_factor,
"top": b["bbox"][1] / scale_factor,
"bottom": b["bbox"][-1] / scale_factor,
"page_number": pn,
}
for b in lts
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts
]
lts = self.sort_Y_firstly(lts, np.mean([lt["bottom"] - lt["top"] for lt in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
def findLayout(ty):
nonlocal bxs, lts, self
lts_ = [lt for lt in lts if lt["type"] == ty]
i = 0
while i < len(bxs):
if bxs[i].get("layout_type"):
i += 1
continue
if __is_garbage(bxs[i]):
bxs.pop(i)
continue
ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4)
if ii is None:
bxs[i]["layout_type"] = ""
i += 1
continue
lts_[ii]["visited"] = True
keep_feats = [
lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
]
if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
bxs.pop(i)
continue
bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"] != "equation" else "figure"
i += 1
for lt in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
findLayout(lt)
# add box to figure layouts which has not text box
for i, lt in enumerate([lt for lt in lts if lt["type"] in ["figure", "equation"]]):
if lt.get("visited"):
continue
lt = deepcopy(lt)
del lt["type"]
lt["text"] = ""
lt["layout_type"] = "figure"
lt["layoutno"] = f"figure-{i}"
bxs.append(lt)
boxes.extend(bxs)
ocr_res = boxes
garbag_set = set()
for k in garbages.keys():
garbages[k] = Counter(garbages[k])
for g, c in garbages[k].items():
if c > 1:
garbag_set.add(g)
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout
def forward(self, image_list, thr=0.7, batch_size=16):
return super().__call__(image_list, thr, batch_size)
class LayoutRecognizer4YOLOv10(LayoutRecognizer):
labels = [
"title",
"Text",
"Reference",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Table caption",
"Equation",
"Figure caption",
]
def __init__(self, domain):
domain = "layout"
super().__init__(domain)
self.auto = False
self.scaleFill = False
self.scaleup = True
self.stride = 32
self.center = True
def preprocess(self, image_list):
inputs = []
new_shape = self.input_shape # height, width
for img in image_list:
shape = img.shape[:2] # current shape [height, width]
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
ww, hh = new_unpad
img = np.array(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).astype(np.float32)
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1] / ww, shape[0] / hh, dw, dh]})
return inputs
def postprocess(self, boxes, inputs, thr):
thr = 0.08
boxes = np.squeeze(boxes)
scores = boxes[:, 4]
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0:
return []
class_ids = boxes[:, -1].astype(int)
boxes = boxes[:, :4]
boxes[:, 0] -= inputs["scale_factor"][2]
boxes[:, 2] -= inputs["scale_factor"][2]
boxes[:, 1] -= inputs["scale_factor"][3]
boxes[:, 3] -= inputs["scale_factor"][3]
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
unique_class_ids = np.unique(class_ids)
indices = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = nms(class_boxes, class_scores, 0.45)
indices.extend(class_indices[class_keep_boxes])
return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices]
class AscendLayoutRecognizer(Recognizer):
labels = [
"title",
"Text",
"Reference",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Table caption",
"Equation",
"Figure caption",
]
def __init__(self, domain):
from ais_bench.infer.interface import InferSession
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
model_file_path = os.path.join(model_dir, domain + ".om")
if not os.path.exists(model_file_path):
raise ValueError(f"Model file not found: {model_file_path}")
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
self.session = InferSession(device_id=device_id, model_path=model_file_path)
self.input_shape = self.session.get_inputs()[0].shape[2:4] # H,W
self.garbage_layouts = ["footer", "header", "reference"]
def preprocess(self, image_list):
inputs = []
H, W = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
r = min(H / h, W / w)
new_unpad = (int(round(w * r)), int(round(h * r)))
dw, dh = (W - new_unpad[0]) / 2.0, (H - new_unpad[1]) / 2.0
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
img /= 255.0
img = img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(np.float32)
inputs.append(
{
"image": img,
"scale_factor": [w / new_unpad[0], h / new_unpad[1]],
"pad": [dw, dh],
"orig_shape": [h, w],
}
)
return inputs
def postprocess(self, boxes, inputs, thr=0.25):
arr = np.squeeze(boxes)
if arr.ndim == 1:
arr = arr.reshape(1, -1)
results = []
if arr.shape[1] == 6:
# [x1,y1,x2,y2,score,cls]
m = arr[:, 4] >= thr
arr = arr[m]
if arr.size == 0:
return []
xyxy = arr[:, :4].astype(np.float32)
scores = arr[:, 4].astype(np.float32)
cls_ids = arr[:, 5].astype(np.int32)
if "pad" in inputs:
dw, dh = inputs["pad"]
sx, sy = inputs["scale_factor"]
xyxy[:, [0, 2]] -= dw
xyxy[:, [1, 3]] -= dh
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
else:
# backup
sx, sy = inputs["scale_factor"]
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
keep_indices = []
for c in np.unique(cls_ids):
idx = np.where(cls_ids == c)[0]
k = nms(xyxy[idx], scores[idx], 0.45)
keep_indices.extend(idx[k])
for i in keep_indices:
cid = int(cls_ids[i])
if 0 <= cid < len(self.labels):
results.append({"type": self.labels[cid].lower(), "bbox": [float(t) for t in xyxy[i].tolist()], "score": float(scores[i])})
return results
raise ValueError(f"Unexpected output shape: {arr.shape}")
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
import re
from collections import Counter
assert len(image_list) == len(ocr_res)
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
layouts_all_pages = [] # list of list[{"type","score","bbox":[x1,y1,x2,y2]}]
conf_thr = max(thr, 0.08)
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for bi in range(batch_loop_cnt):
s = bi * batch_size
e = min((bi + 1) * batch_size, len(images))
batch_images = images[s:e]
inputs_list = self.preprocess(batch_images)
logging.debug("preprocess done")
for ins in inputs_list:
feeds = [ins["image"]]
out_list = self.session.infer(feeds=feeds, mode="static")
for out in out_list:
lts = self.postprocess(out, ins, conf_thr)
page_lts = []
for b in lts:
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts:
x0, y0, x1, y1 = b["bbox"]
page_lts.append(
{
"type": b["type"],
"score": float(b["score"]),
"x0": float(x0) / scale_factor,
"x1": float(x1) / scale_factor,
"top": float(y0) / scale_factor,
"bottom": float(y1) / scale_factor,
"page_number": len(layouts_all_pages),
}
)
layouts_all_pages.append(page_lts)
def _is_garbage_text(box):
patt = [r"^•+$", r"^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", r"^http://[^ ]{12,}", r"\(cid *: *[0-9]+ *\)"]
return any(re.search(p, box.get("text", "")) for p in patt)
boxes_out = []
page_layout = []
garbages = {}
for pn, lts in enumerate(layouts_all_pages):
if lts:
avg_h = np.mean([lt["bottom"] - lt["top"] for lt in lts])
lts = self.sort_Y_firstly(lts, avg_h / 2 if avg_h > 0 else 0)
bxs = ocr_res[pn]
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
def _tag_layout(ty):
nonlocal bxs, lts
lts_of_ty = [lt for lt in lts if lt["type"] == ty]
i = 0
while i < len(bxs):
if bxs[i].get("layout_type"):
i += 1
continue
if _is_garbage_text(bxs[i]):
bxs.pop(i)
continue
ii = self.find_overlapped_with_threshold(bxs[i], lts_of_ty, thr=0.4)
if ii is None:
bxs[i]["layout_type"] = ""
i += 1
continue
lts_of_ty[ii]["visited"] = True
keep_feats = [
lts_of_ty[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].shape[0] * 0.9 / scale_factor,
lts_of_ty[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].shape[0] * 0.1 / scale_factor,
]
if drop and lts_of_ty[ii]["type"] in self.garbage_layouts and not any(keep_feats):
garbages.setdefault(lts_of_ty[ii]["type"], []).append(bxs[i].get("text", ""))
bxs.pop(i)
continue
bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_of_ty[ii]["type"] if lts_of_ty[ii]["type"] != "equation" else "figure"
i += 1
for ty in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
_tag_layout(ty)
figs = [lt for lt in lts if lt["type"] in ["figure", "equation"]]
for i, lt in enumerate(figs):
if lt.get("visited"):
continue
lt = deepcopy(lt)
lt.pop("type", None)
lt["text"] = ""
lt["layout_type"] = "figure"
lt["layoutno"] = f"figure-{i}"
bxs.append(lt)
boxes_out.extend(bxs)
garbag_set = set()
for k, lst in garbages.items():
cnt = Counter(lst)
for g, c in cnt.items():
if c > 1:
garbag_set.add(g)
ocr_res_new = [b for b in boxes_out if b["text"].strip() not in garbag_set]
return ocr_res_new, page_layout

View File

@@ -0,0 +1,737 @@
import gc
import logging
import copy
import time
import os
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.common.misc_utils import pip_install_torch
from app.core.rag.common import settings
from .operators import * # noqa: F403
from . import operators
import math
import numpy as np
import cv2
import onnxruntime as ort
from .postprocess import build_post_process
loaded_models = {}
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(
op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = getattr(operators, op_name)(**param)
ops.append(op)
return ops
def load_model(model_dir, nm, device_id: int | None = None):
model_file_path = os.path.join(model_dir, nm + ".onnx")
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
global loaded_models
loaded_model = loaded_models.get(model_cached_tag)
if loaded_model:
logging.info(f"load_model {model_file_path} reuses cached model")
return loaded_model
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
def cuda_is_available():
try:
pip_install_torch()
import torch
target_id = 0 if device_id is None else device_id
if torch.cuda.is_available() and torch.cuda.device_count() > target_id:
return True
except Exception:
return False
return False
options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution
run_options = ort.RunOptions()
if cuda_is_available():
gpu_mem_limit_mb = int(os.environ.get("OCR_GPU_MEM_LIMIT_MB", "2048"))
arena_strategy = os.environ.get("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo")
provider_device_id = 0 if device_id is None else device_id
cuda_provider_options = {
"device_id": provider_device_id, # Use specific GPU
"gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
"arena_extend_strategy": arena_strategy, # gpu memory allocation strategy
}
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CPUExecutionProvider'])
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
logging.info(f"load_model {model_file_path} uses CPU")
loaded_model = (sess, run_options)
loaded_models[model_cached_tag] = loaded_model
return loaded_model
class TextRecognizer:
def __init__(self, model_dir, device_id: int | None = None):
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(model_dir, "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
w = self.input_tensor.shape[3:][0]
if isinstance(w, str):
pass
elif w is not None and w > 0:
imgW = w
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
'constant',
constant_values=(255))
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
img = img.astype('float32')
return img
def close(self):
# close session and release manually
logging.info('Close text recognizer.')
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
def __del__(self):
self.close()
class TextDetector:
def __init__(self, model_dir, device_id: int | None = None):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
'DetResizeForTest': {
'image_shape': [img_h, img_w]
}
}
self.preprocess_op = create_operators(pre_process_list)
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def close(self):
logging.info("Close text detector.")
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
st = time.time()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
input_dict = {}
input_dict[self.input_tensor.name] = img
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes, time.time() - st
def __del__(self):
self.close()
class OCR:
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
try:
model_dir = os.path.join(
get_project_base_directory(),
"res/deepdoc")
# Append muti-gpus task to the list
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
local_dir_use_symlinks=False)
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
self.drop_score = 0.5
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
# Try original orientation
rec_result = self.text_recognizer[0]([dst_img])
text, score = rec_result[0][0]
best_score = score
best_img = dst_img
# Try clockwise 90° rotation
rotated_cw = np.rot90(dst_img, k=3)
rec_result = self.text_recognizer[0]([rotated_cw])
rotated_cw_text, rotated_cw_score = rec_result[0][0]
if rotated_cw_score > best_score:
best_score = rotated_cw_score
best_img = rotated_cw
# Try counter-clockwise 90° rotation
rotated_ccw = np.rot90(dst_img, k=1)
rec_result = self.text_recognizer[0]([rotated_ccw])
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
if rotated_ccw_score > best_score:
best_img = rotated_ccw
# Use the best image
dst_img = best_img
return dst_img
def sorted_boxes(self, dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def detect(self, img, device_id: int | None = None):
if device_id is None:
device_id = 0
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if img is None:
return None, None, time_dict
start = time.time()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
return zip(self.sorted_boxes(dt_boxes), [
("", 0) for _ in range(len(dt_boxes))])
def recognize(self, ori_im, box, device_id: int | None = None):
if device_id is None:
device_id = 0
img_crop = self.get_rotate_crop_image(ori_im, box)
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
text, score = rec_res[0]
if score < self.drop_score:
return ""
return text
def recognize_batch(self, img_list, device_id: int | None = None):
if device_id is None:
device_id = 0
rec_res, elapse = self.text_recognizer[device_id](img_list)
texts = []
for i in range(len(rec_res)):
text, score = rec_res[i]
if score < self.drop_score:
text = ""
texts.append(text)
return texts
def __call__(self, img, device_id = 0, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if device_id is None:
device_id = 0
if img is None:
return None, None, time_dict
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
img_crop_list = []
dt_boxes = self.sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
time_dict['rec'] = elapse
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
# for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))

View File

@@ -0,0 +1,709 @@
import logging
import sys
import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage:
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert isinstance(img, str) and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert isinstance(img, bytes) and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class StandardizeImag:
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class NormalizeImage:
""" normalize image such as subtract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage:
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class KeepKeys:
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Pad:
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
raise TypeError("Type of target_size is invalid. Now is {}".format(
type(size)))
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
img = data['image']
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
), '(h, w) of target size should be greater than (img_h, img_w)'
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
self.size_div)
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
self.size_div)
img = cv2.copyMakeBorder(
img,
0,
resize_h2 - img_h,
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
value=0)
data['image'] = img
return data
class LinearResize:
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
_im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
_im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class Resize:
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
return data
class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
h, w, c = im.shape
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
im_pad[:h, :w, :] = im
return im_pad
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except BaseException:
logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest:
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize:
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
class SRResize:
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
_label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize:
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
class GrayImageChannelFormat:
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
inverse: inverse gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
data['image'] = np.abs(img_expanded - 1)
else:
data['image'] = img_expanded
data['src_image'] = img
return data
class Permute:
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride:
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
def nms(bboxes, scores, iou_thresh):
import numpy as np
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (y2 - y1) * (x2 - x1)
indices = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0]
indices.append(i)
x11 = np.maximum(x1[i], x1[index[1:]])
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1)
h = np.maximum(0, y22 - y11 + 1)
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= iou_thresh)[0]
index = index[idx + 1]
return indices

View File

@@ -0,0 +1,354 @@
import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
def build_post_process(config, global_config=None):
support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}
config = copy.deepcopy(config)
module_name = config.pop('name')
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
module_class = support_dict.get(module_name)
if module_class is None:
raise ValueError(
'post process only support {}'.format(list(support_dict)))
return module_class(**config)
class DBPostProcess:
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
box_type='quad',
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
_img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if not isinstance(pred, np.ndarray):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
else:
raise ValueError(
"box_type can only be one of ['quad', 'poly']")
boxes_batch.append({'points': boxes})
return boxes_batch
class BaseRecLabelDecode:
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.reverse = False
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
if 'arabic' in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def pred_reverse(self, pred):
pred_re = []
c_current = ''
for c in pred:
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
if c_current != '':
pred_re.append(c_current)
pred_re.append(c)
c_current = ''
else:
c_current += c
if c_current != '':
pred_re.append(c_current)
return ''.join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if not isinstance(preds, np.ndarray):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character

View File

@@ -0,0 +1,427 @@
import gc
import logging
import os
import math
import numpy as np
import cv2
from functools import cmp_to_key
from app.core.rag.common.file_utils import get_project_base_directory
from .operators import * # noqa: F403
from .operators import preprocess
from . import operators
from .ocr import load_model
class Recognizer:
def __init__(self, label_list, task_name, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = os.path.join(
get_project_base_directory(),
"res/deepdoc")
self.ort_sess, self.run_options = load_model(model_dir, task_name)
self.input_names = [node.name for node in self.ort_sess.get_inputs()]
self.output_names = [node.name for node in self.ort_sess.get_outputs()]
self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
self.label_list = label_list
@staticmethod
def sort_Y_firstly(arr, threshold):
def cmp(c1, c2):
diff = c1["top"] - c2["top"]
if abs(diff) < threshold:
diff = c1["x0"] - c2["x0"]
return diff
arr = sorted(arr, key=cmp_to_key(cmp))
return arr
@staticmethod
def sort_X_firstly(arr, threshold):
def cmp(c1, c2):
diff = c1["x0"] - c2["x0"]
if abs(diff) < threshold:
diff = c1["top"] - c2["top"]
return diff
arr = sorted(arr, key=cmp_to_key(cmp))
return arr
@staticmethod
def sort_C_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["x0"], r["top"]))
arr = Recognizer.sort_X_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
# restore the order using th
if "C" not in arr[j] or "C" not in arr[j + 1]:
continue
if arr[j + 1]["C"] < arr[j]["C"] \
or (
arr[j + 1]["C"] == arr[j]["C"]
and arr[j + 1]["top"] < arr[j]["top"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
@staticmethod
def sort_R_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["top"], r["x0"]))
arr = Recognizer.sort_Y_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
if "R" not in arr[j] or "R" not in arr[j + 1]:
continue
if arr[j + 1]["R"] < arr[j]["R"] \
or (
arr[j + 1]["R"] == arr[j]["R"]
and arr[j + 1]["x0"] < arr[j]["x0"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
@staticmethod
def overlapped_area(a, b, ratio=True):
tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
if b["x0"] > x1 or b["x1"] < x0:
return 0
if b["bottom"] < tp or b["top"] > btm:
return 0
x0_ = max(b["x0"], x0)
x1_ = min(b["x1"], x1)
assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(
tp, btm, x0, x1, b)
tp_ = max(b["top"], tp)
btm_ = min(b["bottom"], btm)
assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(
tp, btm, x0, x1, b)
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
x0 != 0 and btm - tp != 0 else 0
if ov > 0 and ratio:
ov /= (x1 - x0) * (btm - tp)
return ov
@staticmethod
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
def not_overlapped(a, b):
return any([a["x1"] < b["x0"],
a["x0"] > b["x1"],
a["bottom"] < b["top"],
a["top"] > b["bottom"]])
i = 0
while i + 1 < len(layouts):
j = i + 1
while j < min(i + far, len(layouts)) \
and (layouts[i].get("type", "") != layouts[j].get("type", "")
or not_overlapped(layouts[i], layouts[j])):
j += 1
if j >= min(i + far, len(layouts)):
i += 1
continue
if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
i += 1
continue
if layouts[i].get("score") and layouts[j].get("score"):
if layouts[i]["score"] > layouts[j]["score"]:
layouts.pop(j)
else:
layouts.pop(i)
continue
area_i, area_i_1 = 0, 0
for b in boxes:
if not not_overlapped(b, layouts[i]):
area_i += Recognizer.overlapped_area(b, layouts[i], False)
if not not_overlapped(b, layouts[j]):
area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
if area_i > area_i_1:
layouts.pop(j)
else:
layouts.pop(i)
return layouts
def create_inputs(self, imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0],)).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'],)).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'],)).astype('float32')
return inputs
im_shape = np.array([info['im_shape'] for info in im_info], dtype='float32')
scale_factor = np.array([info['scale_factor'] for info in im_info], dtype='float32')
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
@staticmethod
def find_overlapped(box, boxes_sorted_by_y, naive=False):
if not boxes_sorted_by_y:
return
bxs = boxes_sorted_by_y
s, e, ii = 0, len(bxs), 0
while s < e and not naive:
ii = (e + s) // 2
pv = bxs[ii]
if box["bottom"] < pv["top"]:
e = ii
continue
if box["top"] > pv["bottom"]:
s = ii + 1
continue
break
while s < ii:
if box["top"] > bxs[s]["bottom"]:
s += 1
break
while e - 1 > ii:
if box["bottom"] < bxs[e - 1]["top"]:
e -= 1
break
max_overlapped_i, max_overlapped = None, 0
for i in range(s, e):
ov = Recognizer.overlapped_area(bxs[i], box)
if ov <= max_overlapped:
continue
max_overlapped_i = i
max_overlapped = ov
return max_overlapped_i
@staticmethod
def find_horizontally_tightest_fit(box, boxes):
if not boxes:
return
min_dis, min_i = 1000000, None
for i,b in enumerate(boxes):
if box.get("layoutno", "0") != b.get("layoutno", "0"):
continue
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis:
min_i = i
min_dis = dis
return min_i
@staticmethod
def find_overlapped_with_threshold(box, boxes, thr=0.3):
if not boxes:
return
max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
s, e = 0, len(boxes)
for i in range(s, e):
ov = Recognizer.overlapped_area(box, boxes[i])
_ov = Recognizer.overlapped_area(boxes[i], box)
if (ov, _ov) < (max_overlapped, _max_overlapped):
continue
max_overlapped_i = i
max_overlapped = ov
_max_overlapped = _ov
return max_overlapped_i
def preprocess(self, image_list):
inputs = []
if "scale_factor" in self.input_names:
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(getattr(operators, op_type)(**new_op_info))
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'),
"scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
else:
hh, ww = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
# Scale input pixel values to 0 to 1
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
return inputs
def postprocess(self, boxes, inputs, thr):
if "scale_factor" in self.input_names:
bb = []
for b in boxes:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
return bb
def xywh2xyxy(x):
# [x, y, w, h] to [x1, y1, x2, y2]
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2
y[:, 1] = x[:, 1] - x[:, 3] / 2
y[:, 2] = x[:, 0] + x[:, 2] / 2
y[:, 3] = x[:, 1] + x[:, 3] / 2
return y
def compute_iou(box, boxes):
# Compute xmin, ymin, xmax, ymax for both boxes
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# Compute intersection area
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# Compute union area
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# Compute IoU
iou = intersection_area / union_area
return iou
def iou_filter(boxes, scores, iou_threshold):
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# Pick the last box
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# Compute IoU of the picked box with the rest
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# Remove boxes with IoU over the threshold
keep_indices = np.where(ious < iou_threshold)[0]
# print(keep_indices.shape, sorted_indices.shape)
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
boxes = np.squeeze(boxes).T
# Filter out object confidence scores below threshold
scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0:
return []
# Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1)
boxes = boxes[:, :4]
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
boxes = xywh2xyxy(boxes)
unique_class_ids = np.unique(class_ids)
indices = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
indices.extend(class_indices[class_keep_boxes])
return [{
"type": self.label_list[class_ids[i]].lower(),
"bbox": [float(t) for t in boxes[i].tolist()],
"score": float(scores[i])
} for i in indices]
def close(self):
logging.info("Close recognizer.")
if hasattr(self, "ort_sess"):
del self.ort_sess
gc.collect()
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
images = []
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
images.append(np.array(image_list[i]))
else:
images.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(images))
batch_image_list = images[start_index:end_index]
inputs = self.preprocess(batch_image_list)
logging.debug("preprocess")
for ins in inputs:
bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names}, self.run_options)[0], ins, thr)
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res
def __del__(self):
self.close()

View File

@@ -0,0 +1,71 @@
import logging
import os
import PIL
from PIL import ImageDraw
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for idx, im in enumerate(image_list):
im = draw_box(im, results[idx], labels, threshold=threshold)
out_path = os.path.join(output_dir, f"{idx}.jpg")
im.save(out_path, quality=95)
logging.debug("save result to: " + out_path)
def draw_box(im, result, labels, threshold=0.5):
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
color_list = get_color_map_list(len(labels))
clsid2color = {n.lower():color_list[i] for i,n in enumerate(labels)}
result = [r for r in result if r["score"] >= threshold]
for dt in result:
color = tuple(clsid2color[dt["type"]])
xmin, ymin, xmax, ymax = dt["bbox"]
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
# draw label
text = "{} {:.4f}".format(dt["type"], dt["score"])
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def imagedraw_textsize_c(draw, text):
if int(PIL.__version__.split('.')[0]) < 10:
tw, th = draw.textsize(text)
else:
left, top, right, bottom = draw.textbbox((0, 0), text)
tw, th = right - left, bottom - top
return tw, th

View File

@@ -0,0 +1,77 @@
import os
import sys
sys.path.insert(
0,
os.path.abspath(
os.path.join(
os.path.dirname(
os.path.abspath(__file__)),
'../../')))
from .seeit import draw_box
from . import OCR, init_in_out
import argparse
import numpy as np
import trio
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
def main(args):
import torch.cuda
cuda_devices = torch.cuda.device_count()
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
ocr = OCR()
images, outputs = init_in_out(args)
def __ocr(i, id, img):
print("Task {} start".format(i))
bxs = ocr(np.array(img), id)
bxs = [(line[0], line[1][0]) for line in bxs]
bxs = [{
"text": t,
"bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
"type": "ocr",
"score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
img = draw_box(images[i], bxs, ["ocr"], 1.)
img.save(outputs[i], quality=95)
with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
f.write("\n".join([o["text"] for o in bxs]))
print("Task {} done".format(i))
async def __ocr_thread(i, id, img, limiter = None):
if limiter:
async with limiter:
print("Task {} use device {}".format(i, id))
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
else:
__ocr(i, id, img)
async def __ocr_launcher():
if cuda_devices > 1:
async with trio.open_nursery() as nursery:
for i, img in enumerate(images):
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
await trio.sleep(0.1)
else:
for i, img in enumerate(images):
await __ocr_thread(i, 0, img)
trio.run(__ocr_launcher)
print("OCR tasks are all done")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
default="./ocr_outputs")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,170 @@
import logging
import os
import sys
sys.path.insert(
0,
os.path.abspath(
os.path.join(
os.path.dirname(
os.path.abspath(__file__)),
'../../')))
from .seeit import draw_box
from . import LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
import argparse
import re
import numpy as np
def main(args):
images, outputs = init_in_out(args)
if args.mode.lower() == "layout":
detr = LayoutRecognizer("layout")
layouts = detr.forward(images, thr=float(args.threshold))
if args.mode.lower() == "tsr":
detr = TableStructureRecognizer()
ocr = OCR()
layouts = detr(images, thr=float(args.threshold))
for i, lyt in enumerate(layouts):
if args.mode.lower() == "tsr":
#lyt = [t for t in lyt if t["type"] == "table column"]
html = get_table_html(images[i], lyt, ocr)
with open(outputs[i] + ".html", "w+", encoding='utf-8') as f:
f.write(html)
lyt = [{
"type": t["label"],
"bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
"score": t["score"]
} for t in lyt]
img = draw_box(images[i], lyt, detr.labels, float(args.threshold))
img.save(outputs[i], quality=95)
logging.info("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
boxes = ocr(np.array(img))
boxes = LayoutRecognizer.sort_Y_firstly(
[{"x0": b[0][0], "x1": b[1][0],
"top": b[0][1], "text": t[0],
"bottom": b[-1][1],
"layout_type": "table",
"page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
)
def gather(kwd, fzy=10, ption=0.6):
nonlocal boxes
eles = LayoutRecognizer.sort_Y_firstly(
[r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
eles = LayoutRecognizer.layouts_cleanup(boxes, eles, 5, ption)
return LayoutRecognizer.sort_Y_firstly(eles, 0)
headers = gather(r".*header$")
rows = gather(r".* (row|header)")
spans = gather(r".*spanning")
clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
for b in boxes:
ii = LayoutRecognizer.find_overlapped_with_threshold(b, rows, thr=0.3)
if ii is not None:
b["R"] = ii
b["R_top"] = rows[ii]["top"]
b["R_bott"] = rows[ii]["bottom"]
ii = LayoutRecognizer.find_overlapped_with_threshold(b, headers, thr=0.3)
if ii is not None:
b["H_top"] = headers[ii]["top"]
b["H_bott"] = headers[ii]["bottom"]
b["H_left"] = headers[ii]["x0"]
b["H_right"] = headers[ii]["x1"]
b["H"] = ii
ii = LayoutRecognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
b["C_right"] = clmns[ii]["x1"]
ii = LayoutRecognizer.find_overlapped_with_threshold(b, spans, thr=0.3)
if ii is not None:
b["H_top"] = spans[ii]["top"]
b["H_bott"] = spans[ii]["bottom"]
b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
html = """
<html>
<head>
<style>
._table_1nkzy_11 {
margin: auto;
width: 70%%;
padding: 10px;
}
._table_1nkzy_11 p {
margin-bottom: 50px;
border: 1px solid #e1e1e1;
}
caption {
color: #6ac1ca;
font-size: 20px;
height: 50px;
line-height: 50px;
font-weight: 600;
margin-bottom: 10px;
}
._table_1nkzy_11 table {
width: 100%%;
border-collapse: collapse;
}
th {
color: #fff;
background-color: #6ac1ca;
}
td:hover {
background: #c1e8e8;
}
tr:nth-child(even) {
background-color: #f2f2f2;
}
._table_1nkzy_11 th,
._table_1nkzy_11 td {
text-align: center;
border: 1px solid #ddd;
padding: 8px;
}
</style>
</head>
<body>
%s
</body>
</html>
""" % TableStructureRecognizer.construct_table(boxes, html=True)
return html
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
default="./layouts_outputs")
parser.add_argument(
'--threshold',
help="A threshold to filter out detections. Default: 0.5",
default=0.5)
parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
default="layout")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,597 @@
import logging
import os
import re
from collections import Counter
import numpy as np
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.nlp import rag_tokenizer
from .recognizer import Recognizer
class TableStructureRecognizer(Recognizer):
labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
def __init__(self):
try:
super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "res/deepdoc"))
except Exception:
super().__init__(
self.labels,
"tsr",
snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
local_dir_use_symlinks=False,
),
)
def __call__(self, images, thr=0.2):
table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower()
if table_structure_recognizer_type not in ["onnx", "ascend"]:
raise RuntimeError("Unsupported table structure recognizer type.")
if table_structure_recognizer_type == "onnx":
logging.debug("Using Onnx table structure recognizer")
tbls = super().__call__(images, thr)
else: # ascend
logging.debug("Using Ascend table structure recognizer")
tbls = self._run_ascend_tsr(images, thr)
res = []
# align left&right for rows, align top&bottom for columns
for tbl in tbls:
lts = [
{
"label": b["type"],
"score": b["score"],
"x0": b["bbox"][0],
"x1": b["bbox"][2],
"top": b["bbox"][1],
"bottom": b["bbox"][-1],
}
for b in tbl
]
if not lts:
continue
left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
if not left:
continue
left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left:
b["x0"] = left
if b["x1"] < right:
b["x1"] = right
top = [b["top"] for b in lts if b["label"] == "table column"]
bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
if not top:
res.append(lts)
continue
top = np.median(top) if len(top) > 4 else np.min(top)
bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
for b in lts:
if b["label"] == "table column":
if b["top"] > top:
b["top"] = top
if b["bottom"] < bottom:
b["bottom"] = bottom
res.append(lts)
return res
@staticmethod
def is_caption(bx):
patt = [r"[图表]+[ 0-9:]{2,}"]
if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0:
return True
return False
@staticmethod
def blockType(b):
patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg"),
]
for p, n in patt:
if re.search(p, b["text"].strip()):
return n
tks = [t for t in rag_tokenizer.tokenize(b["text"]).split() if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
@staticmethod
def construct_table(boxes, is_english=False, html=True, **kwargs):
cap = ""
i = 0
while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]):
if is_english:
cap + " "
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1
i += 1
if not boxes:
return []
for b in boxes:
b["btype"] = TableStructureRecognizer.blockType(b)
max_type = Counter([b["btype"] for b in boxes]).items()
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
logging.debug("MAXTYPE: " + max_type)
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
rowh = np.min(rowh) if rowh else 0
boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
# for b in boxes:print(b)
boxes[0]["rn"] = 0
rows = [[boxes[0]]]
btm = boxes[0]["bottom"]
for b in boxes[1:]:
b["rn"] = len(rows) - 1
lst_r = rows[-1]
if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")): # new row
btm = b["bottom"]
b["rn"] += 1
rows.append([b])
continue
btm = (btm + b["bottom"]) / 2.0
rows[-1].append(b)
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
colwm = np.min(colwm) if colwm else 0
crosspage = len(set([b["page_number"] for b in boxes])) > 1
if crosspage:
boxes = Recognizer.sort_X_firstly(boxes, colwm / 2)
else:
boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
boxes[0]["cn"] = 0
cols = [[boxes[0]]]
right = boxes[0]["x1"]
for b in boxes[1:]:
b["cn"] = len(cols) - 1
lst_c = cols[-1]
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or (
b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")
): # new col
right = b["x1"]
b["cn"] += 1
cols.append([b])
continue
right = (right + b["x1"]) / 2.0
cols[-1].append(b)
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
for b in boxes:
tbl[b["rn"]][b["cn"]].append(b)
if len(rows) >= 4:
# remove single in column
j = 0
while j < len(tbl[0]):
e, ii = 0, 0
for i in range(len(tbl)):
if tbl[i][j]:
e += 1
ii = i
if e > 1:
break
if e > 1:
j += 1
continue
f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
if f and ff:
j += 1
continue
bx = tbl[ii][j][0]
logging.debug("Relocate column single: " + bx["text"])
# j column only has one value
left, right = 100000, 100000
if j > 0 and not f:
for i in range(len(tbl)):
if tbl[i][j - 1]:
left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
if j + 1 < len(tbl[0]) and not ff:
for i in range(len(tbl)):
if tbl[i][j + 1]:
right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
assert left < 100000 or right < 100000
if left < right:
for jj in range(j, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j - 1]:
tbl[ii][j - 1].extend(tbl[ii][j])
else:
tbl[ii][j - 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
else:
for jj in range(j + 1, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j + 1]:
tbl[ii][j + 1].extend(tbl[ii][j])
else:
tbl[ii][j + 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
cols.pop(j)
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0]))
if len(cols) >= 4:
# remove single in row
i = 0
while i < len(tbl):
e, jj = 0, 0
for j in range(len(tbl[i])):
if tbl[i][j]:
e += 1
jj = j
if e > 1:
break
if e > 1:
i += 1
continue
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
if f and ff:
i += 1
continue
bx = tbl[i][jj][0]
logging.debug("Relocate row single: " + bx["text"])
# i row only has one value
up, down = 100000, 100000
if i > 0 and not f:
for j in range(len(tbl[i - 1])):
if tbl[i - 1][j]:
up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
if i + 1 < len(tbl) and not ff:
for j in range(len(tbl[i + 1])):
if tbl[i + 1][j]:
down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
assert up < 100000 or down < 100000
if up < down:
for ii in range(i, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i - 1][jj]:
tbl[i - 1][jj].extend(tbl[i][jj])
else:
tbl[i - 1][jj] = tbl[i][jj]
tbl.pop(i)
else:
for ii in range(i + 1, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i + 1][jj]:
tbl[i + 1][jj].extend(tbl[i][jj])
else:
tbl[i + 1][jj] = tbl[i][jj]
tbl.pop(i)
rows.pop(i)
# which rows are headers
hdset = set([])
for i in range(len(tbl)):
cnt, h = 0, 0
for j, arr in enumerate(tbl[i]):
if not arr:
continue
cnt += 1
if max_type == "Nu" and arr[0]["btype"] == "Nu":
continue
if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
h += 1
if h / cnt > 0.5:
hdset.add(i)
if html:
return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))
return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english)
@staticmethod
def __html_table(cap, hdset, tbl):
# constrcut HTML
html = "<table>"
if cap:
html += f"<caption>{cap}</caption>"
for i in range(len(tbl)):
row = "<tr>"
txts = []
for j, arr in enumerate(tbl[i]):
if arr is None:
continue
if not arr:
row += "<td></td>" if i not in hdset else "<th></th>"
continue
txt = ""
if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
txts.append(txt)
sp = ""
if arr[0].get("colspan"):
sp = "colspan={}".format(arr[0]["colspan"])
if arr[0].get("rowspan"):
sp += " rowspan={}".format(arr[0]["rowspan"])
if i in hdset:
row += f"<th {sp} >" + txt + "</th>"
else:
row += f"<td {sp} >" + txt + "</td>"
if i in hdset:
if all([t in hdset for t in txts]):
continue
for t in txts:
hdset.add(t)
if row != "<tr>":
row += "</tr>"
else:
row = ""
html += "\n" + row
html += "\n</table>"
return html
@staticmethod
def __desc_table(cap, hdr_rowno, tbl, is_english):
# get text of every colomn in header row to become header text
clmno = len(tbl[0])
rowno = len(tbl)
headers = {}
hdrset = set()
lst_hdr = []
de = "" if not is_english else " for "
for r in sorted(list(hdr_rowno)):
headers[r] = ["" for _ in range(clmno)]
for i in range(clmno):
if not tbl[r][i]:
continue
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt
hdrset.add(txt)
if all([not t for t in headers[r]]):
del headers[r]
hdr_rowno.remove(r)
continue
for j in range(clmno):
if headers[r][j]:
continue
if j >= len(lst_hdr):
break
headers[r][j] = lst_hdr[j]
lst_hdr = headers[r]
for i in range(rowno):
if i not in hdr_rowno:
continue
for j in range(i + 1, rowno):
if j not in hdr_rowno:
break
for k in range(clmno):
if not headers[j - 1][k]:
continue
if headers[j][k].find(headers[j - 1][k]) >= 0:
continue
if len(headers[j][k]) > len(headers[j - 1][k]):
headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k]
else:
headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k]
logging.debug(f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
row_txt = []
for i in range(rowno):
if i in hdr_rowno:
continue
rtxt = []
def append(delimer):
nonlocal rtxt, row_txt
rtxt = delimer.join(rtxt)
if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
row_txt[-1] += "\n" + rtxt
else:
row_txt.append(rtxt)
r = 0
if len(headers.items()):
_arr = [(i - r, r) for r, _ in headers.items() if r < i]
if _arr:
_, r = min(_arr, key=lambda x: x[0])
if r not in headers and clmno <= 2:
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if txt:
rtxt.append(txt)
if rtxt:
append("")
continue
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if not txt:
continue
ctt = headers[r][j] if r in headers else ""
if ctt:
ctt += ""
ctt += txt
if ctt:
rtxt.append(ctt)
if rtxt:
row_txt.append("; ".join(rtxt))
if cap:
if is_english:
from_ = " in "
else:
from_ = "来自"
row_txt = [t + f"\t——{from_}{cap}" for t in row_txt]
return row_txt
@staticmethod
def __cal_spans(boxes, rows, cols, tbl, html=True):
# caculate span
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols]
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols]
rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows]
rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows]
for b in boxes:
if "SP" not in b:
continue
b["colspan"] = [b["cn"]]
b["rowspan"] = [b["rn"]]
# col span
for j in range(0, len(clft)):
if j == b["cn"]:
continue
if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
continue
if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
continue
b["colspan"].append(j)
# row span
for j in range(0, len(rtop)):
if j == b["rn"]:
continue
if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
continue
if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
continue
b["rowspan"].append(j)
def join(arr):
if not arr:
return ""
return "".join([t["text"] for t in arr])
# rm the spaning cells
for i in range(len(tbl)):
for j, arr in enumerate(tbl[i]):
if not arr:
continue
if all(["rowspan" not in a and "colspan" not in a for a in arr]):
continue
rowspan, colspan = [], []
for a in arr:
if isinstance(a.get("rowspan", 0), list):
rowspan.extend(a["rowspan"])
if isinstance(a.get("colspan", 0), list):
colspan.extend(a["colspan"])
rowspan, colspan = set(rowspan), set(colspan)
if len(rowspan) < 2 and len(colspan) < 2:
for a in arr:
if "rowspan" in a:
del a["rowspan"]
if "colspan" in a:
del a["colspan"]
continue
rowspan, colspan = sorted(rowspan), sorted(colspan)
rowspan = list(range(rowspan[0], rowspan[-1] + 1))
colspan = list(range(colspan[0], colspan[-1] + 1))
assert i in rowspan, rowspan
assert j in colspan, colspan
arr = []
for r in rowspan:
for c in colspan:
arr_txt = join(arr)
if tbl[r][c] and join(tbl[r][c]) != arr_txt:
arr.extend(tbl[r][c])
tbl[r][c] = None if html else arr
for a in arr:
if len(rowspan) > 1:
a["rowspan"] = len(rowspan)
elif "rowspan" in a:
del a["rowspan"]
if len(colspan) > 1:
a["colspan"] = len(colspan)
elif "colspan" in a:
del a["colspan"]
tbl[rowspan[0]][colspan[0]] = arr
return tbl
def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
import math
from ais_bench.infer.interface import InferSession
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
model_file_path = os.path.join(model_dir, "tsr.om")
if not os.path.exists(model_file_path):
raise ValueError(f"Model file not found: {model_file_path}")
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
session = InferSession(device_id=device_id, model_path=model_file_path)
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
results = []
conf_thr = max(thr, 0.08)
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for bi in range(batch_loop_cnt):
s = bi * batch_size
e = min((bi + 1) * batch_size, len(images))
batch_images = images[s:e]
inputs_list = self.preprocess(batch_images)
for ins in inputs_list:
feeds = []
if "image" in ins:
feeds.append(ins["image"])
else:
feeds.append(ins[self.input_names[0]])
output_list = session.infer(feeds=feeds, mode="static")
bb = self.postprocess(output_list, ins, conf_thr)
results.append(bb)
return results