feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

View File

View File

@@ -0,0 +1,42 @@
import os
import re
import tempfile
from app.core.rag.nlp import rag_tokenizer, tokenize
def chunk(filename, binary, lang, callback=None, seq2txt_mdl=None, **kwargs):
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
# is it English
eng = lang.lower() == "english" # is_english(sections)
try:
_, ext = os.path.splitext(filename)
if not ext:
raise RuntimeError("No extension detected.")
if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".ape"]:
raise RuntimeError(f"Extension {ext} is not supported yet.")
tmp_path = ""
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf:
tmpf.write(binary)
tmpf.flush()
tmp_path = os.path.abspath(tmpf.name)
callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
ans = seq2txt_mdl.transcription(tmp_path)
callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])
tokenize(doc, ans, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
finally:
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except Exception:
pass
return []

View File

@@ -0,0 +1,170 @@
import logging
import re
from io import BytesIO
from app.core.rag.deepdoc.parser.utils import get_text
from . import naive
from .naive import by_plaintext, PARSERS
from app.core.rag.nlp import bullets_category, is_english,remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices, tokenize_table, \
tokenize_chunks
from app.core.rag.nlp import rag_tokenizer
from app.core.rag.deepdoc.parser import PdfParser, HtmlParser
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from PIL import Image
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts: {}".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
tbls = self._extract_table_figure(True, zoomin, True, True)
self._naive_vertical_merge()
self._filter_forpages()
self._merge_with_same_bullet()
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", ""))
for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, txt.
Since a book is long and not all the parts are useful, if it's a PDF,
please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
pdf_parser = None
sections, tbls = [], []
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
doc_parser = naive.Docx()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(
filename, binary=binary, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
# tbls = [((None, lns), None) for lns in tbls]
sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)
if not sections and not tables:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
else:
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")
make_colon_as_title(sections)
bull = bullets_category(
[t for t in random_choices([t for t, _ in sections], k=100)])
if bull >= 0:
chunks = ["\n".join(ck)
for ck in hierarchical_merge(bull, sections, 5)]
else:
sections = [s.split("@") for s, _ in sections]
sections = [(pr[0], "@" + pr[1]) if len(pr) == 2 else (pr[0], '') for pr in sections ]
chunks = naive_merge(
sections, kwargs.get(
"chunk_token_num", 256), kwargs.get(
"delimer", "\n。;!?"))
# is it English
# is_english(random_choices([t for t, _ in sections], k=218))
eng = lang.lower() == "english"
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=1, to_page=10, callback=dummy)

View File

@@ -0,0 +1,219 @@
import logging
import re
from io import BytesIO
from docx import Document
from app.core.rag.common.constants import ParserType
from app.core.rag.deepdoc.parser.utils import get_text
from app.core.rag.nlp import bullets_category, remove_contents_table, \
make_colon_as_title, tokenize_chunks, docx_question_level, tree_merge
from app.core.rag.nlp import rag_tokenizer, Node
from app.core.rag.deepdoc.parser import PdfParser, DocxParser, HtmlParser
from app.core.rag.app.naive import by_plaintext, PARSERS
class Docx(DocxParser):
def __init__(self):
pass
def __clean(self, line):
line = re.sub(r"\u3000", " ", line).strip()
return line
def old_call(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
lines = []
for p in self.doc.paragraphs:
if pn > to_page:
break
if from_page <= pn < to_page and p.text.strip():
lines.append(self.__clean(p.text))
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
return [line for line in lines if line]
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
lines = []
level_set = set()
bull = bullets_category([p.text for p in self.doc.paragraphs])
for p in self.doc.paragraphs:
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):
continue
lines.append((question_level, p_text))
level_set.add(question_level)
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
sorted_levels = sorted(level_set)
h2_level = sorted_levels[1] if len(sorted_levels) > 1 else 1
h2_level = sorted_levels[-2] if h2_level == sorted_levels[-1] and len(sorted_levels) > 2 else h2_level
root = Node(level=0, depth=h2_level, texts=[])
root.build_tree(lines)
return [element for element in root.get_tree() if element]
def __str__(self) -> str:
return f'''
question:{self.question},
answer:{self.answer},
level:{self.level},
childs:{self.childs}
'''
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.LAWS.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts:".format(
))
self._naive_vertical_merge()
callback(0.8, "Text extraction ({:.2f}s)".format(timer() - start))
return [(b["text"], self._line_tag(b, zoomin))
for b in self.boxes], None
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, txt.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
pdf_parser = None
sections = []
# is it English
eng = lang.lower() == "english" # is_english(sections)
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
chunks = Docx()(filename, binary)
callback(0.7, "Finish parsing.")
return tokenize_chunks(chunks, doc, eng, None)
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
raw_sections, tables, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)
if not raw_sections and not tables:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
for txt, poss in raw_sections:
sections.append(txt + poss)
callback(0.8, "Finish parsing.")
elif re.search(r"\.(txt|md|markdown|mdx)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")
# Remove 'Contents' part
remove_contents_table(sections, eng)
make_colon_as_title(sections)
bull = bullets_category(sections)
res = tree_merge(bull, sections, 2)
if not res:
callback(0.99, "No chunk parsed out.")
return tokenize_chunks(res, doc, eng, pdf_parser)
# chunks = hierarchical_merge(bull, sections, 5)
# return tokenize_chunks(["\n".join(ck)for ck in chunks], doc, eng, pdf_parser)
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -0,0 +1,114 @@
import logging
from email import policy
from email.parser import BytesParser
from .naive import chunk as naive_chunk
import re
from app.core.rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks
from app.core.rag.deepdoc.parser import HtmlParser, TxtParser
from timeit import default_timer as timer
import io
def chunk(
filename,
binary=None,
from_page=0,
to_page=100000,
lang="Chinese",
callback=None,
**kwargs,
):
"""
Only eml is supported
"""
eng = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config",
{"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
)
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
main_res = []
attachment_res = []
if binary:
msg = BytesParser(policy=policy.default).parse(io.BytesIO(binary))
else:
msg = BytesParser(policy=policy.default).parse(open(filename, "rb"))
text_txt, html_txt = [], []
# get the email header info
for header, value in msg.items():
text_txt.append(f"{header}: {value}")
# get the email main info
def _add_content(msg, content_type):
def _decode_payload(payload, charset, target_list):
try:
target_list.append(payload.decode(charset))
except (UnicodeDecodeError, LookupError):
for enc in ["utf-8", "gb2312", "gbk", "gb18030", "latin1"]:
try:
target_list.append(payload.decode(enc))
break
except UnicodeDecodeError:
continue
else:
target_list.append(payload.decode("utf-8", errors="ignore"))
if content_type == "text/plain":
payload = msg.get_payload(decode=True)
charset = msg.get_content_charset() or "utf-8"
_decode_payload(payload, charset, text_txt)
elif content_type == "text/html":
payload = msg.get_payload(decode=True)
charset = msg.get_content_charset() or "utf-8"
_decode_payload(payload, charset, html_txt)
elif "multipart" in content_type:
if msg.is_multipart():
for part in msg.iter_parts():
_add_content(part, part.get_content_type())
_add_content(msg, msg.get_content_type())
sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt), chunk_token_num=parser_config["chunk_token_num"]) if line
]
st = timer()
chunks = naive_merge(
sections,
int(parser_config.get("chunk_token_num", 128)),
parser_config.get("delimiter", "\n!?。;!?"),
)
main_res.extend(tokenize_chunks(chunks, doc, eng, None))
logging.debug("naive_merge({}): {}".format(filename, timer() - st))
# get the attachment info
for part in msg.iter_attachments():
content_disposition = part.get("Content-Disposition")
if content_disposition:
dispositions = content_disposition.strip().split(";")
if dispositions[0].lower() == "attachment":
filename = part.get_filename()
payload = part.get_payload(decode=True)
try:
attachment_res.extend(
naive_chunk(filename, payload, callback=callback, **kwargs)
)
except Exception:
pass
return main_res + attachment_res
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -0,0 +1,299 @@
import logging
import copy
import re
from app.core.rag.common.constants import ParserType
from io import BytesIO
from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, bullets_category, title_frequency, tokenize_chunks, docx_question_level
from app.core.rag.common.token_utils import num_tokens_from_string
from app.core.rag.deepdoc.parser import PdfParser, DocxParser
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper,vision_figure_parser_docx_wrapper
from docx import Document
from PIL import Image
from .naive import by_plaintext, PARSERS
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.MANUAL.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
logging.debug("OCR: {}".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.65, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts: {}".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
callback(0.67, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
tbls = self._extract_table_figure(True, zoomin, True, True)
self._concat_downward()
self._filter_forpages()
callback(0.68, "Text merged ({:.2f}s)".format(timer() - start))
# clean mess
for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
return [(b["text"], b.get("layoutno", ""), self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)], tbls
class Docx(DocxParser):
def __init__(self):
pass
def get_picture(self, document, paragraph):
img = paragraph._element.xpath('.//pic:pic')
if not img:
return None
try:
img = img[0]
embed = img.xpath('.//a:blip/@r:embed')[0]
related_part = document.part.related_parts[embed]
image = related_part.image
if image is not None:
image = Image.open(BytesIO(image.blob))
return image
elif related_part.blob is not None:
image = Image.open(BytesIO(related_part.blob))
return image
else:
return None
except Exception:
return None
def concat_img(self, img1, img2):
if img1 and not img2:
return img1
if not img1 and img2:
return img2
if not img1 and not img2:
return None
width1, height1 = img1.size
width2, height2 = img2.size
new_width = max(width1, width2)
new_height = height1 + height2
new_image = Image.new('RGB', (new_width, new_height))
new_image.paste(img1, (0, 0))
new_image.paste(img2, (0, height1))
return new_image
def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
last_answer, last_image = "", None
question_stack, level_stack = [], []
ti_list = []
for p in self.doc.paragraphs:
if pn > to_page:
break
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = self.concat_img(last_image, current_image)
else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
ti_list.append((f'{sum_question}\n{last_answer}', last_image))
last_answer, last_image = '', None
i = question_level
while question_stack and i <= level_stack[-1]:
question_stack.pop()
level_stack.pop()
question_stack.append(p_text)
level_stack.append(question_level)
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
if last_answer:
sum_question = '\n'.join(question_stack)
if sum_question:
ti_list.append((f'{sum_question}\n{last_answer}', last_image))
tbls = []
for tb in self.doc.tables:
html= "<table>"
for r in tb.rows:
html += "<tr>"
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
for j in range(i+1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
else:
break
i += 1
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
html += "</tr>"
html += "</table>"
tbls.append(((None, html), ""))
return ti_list, tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Only pdf is supported.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
pdf_parser = None
doc = {
"docnm_kwd": filename
}
doc["title_tks"] = rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
# is it English
eng = lang.lower() == "english" # pdf_parser.is_english
if re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
pdf_parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = pdf_parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)
if not sections and not tbls:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.03:
max_lvl = max([lvl for _, lvl in pdf_parser.outlines])
most_level = max(0, max_lvl - 1)
levels = []
for txt, _, _ in sections:
for t, lvl in pdf_parser.outlines:
tks = set([t[i] + t[i + 1] for i in range(len(t) - 1)])
tks_ = set([txt[i] + txt[i + 1]
for i in range(min(len(t), len(txt) - 1))])
if len(set(tks & tks_)) / max([len(tks), len(tks_), 1]) > 0.8:
levels.append(lvl)
break
else:
levels.append(max_lvl + 1)
else:
bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency(
bull, [(txt, lvl) for txt, lvl, _ in sections])
assert len(sections) == len(levels)
sec_ids = []
sid = 0
for i, lvl in enumerate(levels):
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1
sec_ids.append(sid)
sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls:
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
def tag(pn, left, right, top, bottom):
if pn + left + right + top + bottom == 0:
return ""
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(pn, left, right, top, bottom)
chunks = []
last_sid = -2
tk_cnt = 0
for txt, sec_id, poss in sorted(sections, key=lambda x: (
x[-1][0][0], x[-1][0][3], x[-1][0][1])):
poss = "\t".join([tag(*pos) for pos in poss])
if tk_cnt < 32 or (tk_cnt < 1024 and (sec_id == last_sid or sec_id == -1)):
if chunks:
chunks[-1] += "\n" + txt + poss
tk_cnt += num_tokens_from_string(txt)
continue
chunks.append(txt + poss)
tk_cnt = num_tokens_from_string(txt)
if sec_id > -1:
last_sid = sec_id
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res
elif re.search(r"\.docx?$", filename, re.IGNORECASE):
docx_parser = Docx()
ti_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
res = tokenize_table(tbls, doc, eng)
for text, image in ti_list:
d = copy.deepcopy(doc)
if image:
d['image'] = image
d["doc_type_kwd"] = "image"
tokenize(d, text, eng)
res.append(d)
return res
else:
raise NotImplementedError("file type not supported yet(pdf and docx supported)")
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -0,0 +1,849 @@
import logging
import re
import os
from functools import reduce
from io import BytesIO
from timeit import default_timer as timer
from docx import Document
from docx.image.exceptions import InvalidImageStreamError, UnexpectedEndOfFileError, UnrecognizedImageError
from docx.opc.pkgreader import _SerializedRelationships, _SerializedRelationship
from docx.opc.oxml import parse_xml
from markdown import markdown
from PIL import Image
import copy
from app.core.rag.llm.cv_model import AzureGptV4, QWenCV
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.utils.file_utils import extract_embed_file, extract_links_from_pdf, extract_links_from_docx, extract_html
from app.core.rag.deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownElementExtractor, MarkdownParser, PdfParser, TxtParser
from app.core.rag.deepdoc.parser.figure_parser import VisionFigureParser,vision_figure_parser_docx_wrapper,vision_figure_parser_pdf_wrapper
from app.core.rag.deepdoc.parser.pdf_parser import PlainParser, VisionParser
from app.core.rag.deepdoc.parser.mineru_parser import MinerUParser
from app.core.rag.nlp import concat_img, find_codec, naive_merge, naive_merge_with_images, naive_merge_docx, tokenize, rag_tokenizer, tokenize_chunks, tokenize_chunks_with_images, tokenize_table
def by_deepdoc(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
callback = callback
binary = binary
pdf_parser = pdf_cls() if pdf_cls else Pdf()
sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
tables = vision_figure_parser_pdf_wrapper(tbls=tables,
callback=callback,
vision_model=vision_model,
**kwargs)
return sections, tables, pdf_parser
def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
mineru_executable = os.environ.get("MINERU_EXECUTABLE", "mineru")
mineru_api = os.environ.get("MINERU_APISERVER", "http://host.docker.internal:9987")
pdf_parser = MinerUParser(mineru_path=mineru_executable, mineru_api=mineru_api)
if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
return None, None, pdf_parser
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None ,**kwargs):
textln_app_id = os.environ.get("TEXTLN_APP_ID", "")
textln_secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
pdf_parser = MinerUParser(mineru_path=textln_app_id, mineru_api=textln_api)
if not pdf_parser.check_installation():
callback(-1, "MinerU not found.")
return None, None, pdf_parser
sections, tables = pdf_parser.parse_pdf(
filepath=filename,
binary=binary,
callback=callback,
output_dir=os.environ.get("MINERU_OUTPUT_DIR", ""),
backend=os.environ.get("MINERU_BACKEND", "pipeline"),
delete_output=bool(int(os.environ.get("MINERU_DELETE_OUTPUT", 1))),
)
return sections, tables, pdf_parser
def by_plaintext(filename, binary=None, from_page=0, to_page=100000, callback=None, vision_model=None, **kwargs):
if kwargs.get("layout_recognizer", "") == "Plain Text":
pdf_parser = PlainParser()
else:
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, tables = pdf_parser(
filename if not binary else binary,
from_page=from_page,
to_page=to_page,
callback=callback
)
return sections, tables, pdf_parser
PARSERS = {
"deepdoc": by_deepdoc,
"mineru": by_mineru,
"textln": by_textln,
"plaintext": by_plaintext, # default
}
class Docx(DocxParser):
def __init__(self):
pass
def get_picture(self, document, paragraph):
imgs = paragraph._element.xpath('.//pic:pic')
if not imgs:
return None
res_img = None
for img in imgs:
embed = img.xpath('.//a:blip/@r:embed')
if not embed:
continue
embed = embed[0]
try:
related_part = document.part.related_parts[embed]
image_blob = related_part.image.blob
except UnrecognizedImageError:
logging.info("Unrecognized image format. Skipping image.")
continue
except UnexpectedEndOfFileError:
logging.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
continue
except InvalidImageStreamError:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
except UnicodeDecodeError:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
except Exception:
logging.info("The recognized image stream appears to be corrupted. Skipping image.")
continue
try:
image = Image.open(BytesIO(image_blob)).convert('RGB')
if res_img is None:
res_img = image
else:
res_img = concat_img(res_img, image)
except Exception:
continue
return res_img
def __clean(self, line):
line = re.sub(r"\u3000", " ", line).strip()
return line
def __get_nearest_title(self, table_index, filename):
"""Get the hierarchical title structure before the table"""
import re
from docx.text.paragraph import Paragraph
titles = []
blocks = []
# Get document name from filename parameter
doc_name = re.sub(r"\.[a-zA-Z]+$", "", filename)
if not doc_name:
doc_name = "Untitled Document"
# Collect all document blocks while maintaining document order
try:
# Iterate through all paragraphs and tables in document order
for i, block in enumerate(self.doc._element.body):
if block.tag.endswith('p'): # Paragraph
p = Paragraph(block, self.doc)
blocks.append(('p', i, p))
elif block.tag.endswith('tbl'): # Table
blocks.append(('t', i, None)) # Table object will be retrieved later
except Exception as e:
logging.error(f"Error collecting blocks: {e}")
return ""
# Find the target table position
target_table_pos = -1
table_count = 0
for i, (block_type, pos, _) in enumerate(blocks):
if block_type == 't':
if table_count == table_index:
target_table_pos = pos
break
table_count += 1
if target_table_pos == -1:
return "" # Target table not found
# Find the nearest heading paragraph in reverse order
nearest_title = None
for i in range(len(blocks)-1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
if block_type != 'p':
continue
if block.style and block.style.name and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
try:
level_match = re.search(r"(\d+)", block.style.name)
if level_match:
level = int(level_match.group(1))
if level <= 7: # Support up to 7 heading levels
title_text = block.text.strip()
if title_text: # Avoid empty titles
nearest_title = (level, title_text)
break
except Exception as e:
logging.error(f"Error parsing heading level: {e}")
if nearest_title:
# Add current title
titles.append(nearest_title)
current_level = nearest_title[0]
# Find all parent headings, allowing cross-level search
while current_level > 1:
found = False
for i in range(len(blocks)-1, -1, -1):
block_type, pos, block = blocks[i]
if pos >= target_table_pos: # Skip blocks after the table
continue
if block_type != 'p':
continue
if block.style and re.search(r"Heading\s*(\d+)", block.style.name, re.I):
try:
level_match = re.search(r"(\d+)", block.style.name)
if level_match:
level = int(level_match.group(1))
# Find any heading with a higher level
if level < current_level:
title_text = block.text.strip()
if title_text: # Avoid empty titles
titles.append((level, title_text))
current_level = level
found = True
break
except Exception as e:
logging.error(f"Error parsing parent heading: {e}")
if not found: # Break if no parent heading is found
break
# Sort by level (ascending, from highest to lowest)
titles.sort(key=lambda x: x[0])
# Organize titles (from highest to lowest)
hierarchy = [doc_name] + [t[1] for t in titles]
return " > ".join(hierarchy)
return ""
def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
lines = []
last_image = None
for p in self.doc.paragraphs:
if pn > to_page:
break
if from_page <= pn < to_page:
if p.text.strip():
if p.style and p.style.name == 'Caption':
former_image = None
if lines and lines[-1][1] and lines[-1][2] != 'Caption':
former_image = lines[-1][1].pop()
elif last_image:
former_image = last_image
last_image = None
lines.append((self.__clean(p.text), [former_image], p.style.name))
else:
current_image = self.get_picture(self.doc, p)
image_list = [current_image]
if last_image:
image_list.insert(0, last_image)
last_image = None
lines.append((self.__clean(p.text), image_list, p.style.name if p.style else ""))
else:
if current_image := self.get_picture(self.doc, p):
if lines:
lines[-1][1].append(current_image)
else:
last_image = current_image
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
new_line = [(line[0], reduce(concat_img, line[1]) if line[1] else None) for line in lines]
tbls = []
for i, tb in enumerate(self.doc.tables):
title = self.__get_nearest_title(i, filename)
html = "<table>"
if title:
html += f"<caption>Table Location: {title}</caption>"
for r in tb.rows:
html += "<tr>"
i = 0
try:
while i < len(r.cells):
span = 1
c = r.cells[i]
for j in range(i + 1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
else:
break
i += 1
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
except Exception as e:
logging.warning(f"Error parsing table, ignore: {e}")
html += "</tr>"
html += "</table>"
tbls.append(((None, html), ""))
return new_line, tbls
def to_markdown(self, filename=None, binary=None, inline_images: bool = True):
"""
This function uses mammoth, licensed under the BSD 2-Clause License.
"""
import base64
import uuid
import mammoth
from markdownify import markdownify
docx_file = BytesIO(binary) if binary else open(filename, "rb")
def _convert_image_to_base64(image):
try:
with image.open() as image_file:
image_bytes = image_file.read()
encoded = base64.b64encode(image_bytes).decode("utf-8")
base64_url = f"data:{image.content_type};base64,{encoded}"
alt_name = "image"
alt_name = f"img_{uuid.uuid4().hex[:8]}"
return {"src": base64_url, "alt": alt_name}
except Exception as e:
logging.warning(f"Failed to convert image to base64: {e}")
return {"src": "", "alt": "image"}
try:
if inline_images:
result = mammoth.convert_to_html(docx_file, convert_image=mammoth.images.img_element(_convert_image_to_base64))
else:
result = mammoth.convert_to_html(docx_file)
html = result.value
markdown_text = markdownify(html)
return markdown_text
finally:
if not binary:
docx_file.close()
class Pdf(PdfParser):
def __init__(self):
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None, separate_tables_figures=False):
start = timer()
first_start = start
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
logging.info("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge(zoomin=zoomin)
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
if separate_tables_figures:
tbls, figures = self._extract_table_figure(True, zoomin, True, True, True)
self._concat_downward()
logging.info("layouts cost: {}s".format(timer() - first_start))
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls, figures
else:
tbls = self._extract_table_figure(True, zoomin, True, True)
self._naive_vertical_merge()
self._concat_downward()
self._final_reading_order_merge()
# self._filter_forpages()
logging.info("layouts cost: {}s".format(timer() - first_start))
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
class Markdown(MarkdownParser):
def md_to_html(self, sections):
if not sections:
return []
if isinstance(sections, type("")):
text = sections
elif isinstance(sections[0], type("")):
text = sections[0]
else:
return []
from bs4 import BeautifulSoup
html_content = markdown(text)
soup = BeautifulSoup(html_content, 'html.parser')
return soup
def get_picture_urls(self, soup):
if soup:
return [img.get('src') for img in soup.find_all('img') if img.get('src')]
return []
def get_hyperlink_urls(self, soup):
if soup:
return set([a.get('href') for a in soup.find_all('a') if a.get('href')])
return []
def get_pictures(self, text):
"""Download and open all images from markdown text."""
import requests
soup = self.md_to_html(text)
image_urls = self.get_picture_urls(soup)
images = []
# Find all image URLs in text
for url in image_urls:
if not url:
continue
try:
# check if the url is a local file or a remote URL
if url.startswith(('http://', 'https://')):
# For remote URLs, download the image
response = requests.get(url, stream=True, timeout=30)
if response.status_code == 200 and response.headers['Content-Type'] and response.headers['Content-Type'].startswith('image/'):
img = Image.open(BytesIO(response.content)).convert('RGB')
images.append(img)
else:
# For local file paths, open the image directly
from pathlib import Path
local_path = Path(url)
if not local_path.exists():
logging.warning(f"Local image file not found: {url}")
continue
img = Image.open(url).convert('RGB')
images.append(img)
except Exception as e:
logging.error(f"Failed to download/open image from {url}: {e}")
continue
return images if images else None
def __call__(self, filename, binary=None, separate_tables=True,delimiter=None):
if binary:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
else:
with open(filename, "r") as f:
txt = f.read()
remainder, tables = self.extract_tables_and_remainder(f'{txt}\n', separate_tables=separate_tables)
# To eliminate duplicate tables in chunking result, uncomment code below and set separate_tables to True in line 410.
# extractor = MarkdownElementExtractor(remainder)
extractor = MarkdownElementExtractor(txt)
element_sections = extractor.extract_elements(delimiter)
sections = [(element, "") for element in element_sections]
tbls = []
for table in tables:
tbls.append(((None, markdown(table, extensions=['markdown.extensions.tables'])), ""))
return sections, tbls
def load_from_xml_v2(baseURI, rels_item_xml):
"""
Return |_SerializedRelationships| instance loaded with the
relationships contained in *rels_item_xml*. Returns an empty
collection if *rels_item_xml* is |None|.
"""
srels = _SerializedRelationships()
if rels_item_xml is not None:
rels_elm = parse_xml(rels_item_xml)
for rel_elm in rels_elm.Relationship_lst:
if rel_elm.target_ref in ('../NULL', 'NULL'):
continue
srels._srels.append(_SerializedRelationship(baseURI, rel_elm))
return srels
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, vision_model=None, **kwargs):
"""
Supported file formats are docx, doc, pdf, excel, txt, markdown, html, json.
This method apply the naive ways to chunk files.
Successive text will be sliced into pieces using 'delimiter'.
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
"""
urls = set()
url_res = []
is_english = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config", {
"layout_recognize": "DeepDOC", "chunk_token_num": 512, "delimiter": "\n!?。;!?", "analyze_hyperlink": True})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
res = []
pdf_parser = None
section_images = None
is_root = kwargs.get("is_root", True)
embed_res = []
if is_root:
# Only extract embedded files at the root call
embeds = []
if binary is not None:
embeds = extract_embed_file(binary)
else:
raise Exception("Embedding extraction from file path is not supported.")
# Recursively chunk each embedded file and collect results
for embed_filename, embed_bytes in embeds:
try:
sub_res = chunk(embed_filename, binary=embed_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs) or []
embed_res.extend(sub_res)
except Exception as e:
if callback:
callback(0.05, f"Failed to chunk embed {embed_filename}: {e}")
continue
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
if parser_config.get("analyze_hyperlink", False) and is_root:
urls = extract_links_from_docx(binary)
for index, url in enumerate(urls):
html_bytes, metadata = extract_html(url)
if not html_bytes:
continue
try:
sub_url_res = chunk(url, html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
url_res.extend(sub_url_res)
# fix "There is no item named 'word/NULL' in the archive", referring to https://github.com/python-openxml/python-docx/issues/1105#issuecomment-1298075246
_SerializedRelationships.load_from_xml = load_from_xml_v2
sections, tables = Docx()(filename, binary)
tables=vision_figure_parser_docx_wrapper(sections=sections,tbls=tables,callback=callback, vision_model=vision_model, **kwargs)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
st = timer()
chunks, images = naive_merge_docx(
sections, int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
chunks.extend(url_res)
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
logging.info("naive_merge({}): {}".format(filename, timer() - st))
res.extend(embed_res)
res.extend(url_res)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if parser_config.get("analyze_hyperlink", False) and is_root:
urls = extract_links_from_pdf(binary)
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, tables, pdf_parser = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
vision_model=vision_model,
layout_recognizer=layout_recognizer,
**kwargs
)
if not sections and not tables:
return []
if name in ["mineru", "textln"]:
parser_config["chunk_token_num"] = 0
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
elif re.search(r"\.pptx?$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
from app.core.rag.app.presentation import Ppt
ppt_parser = Ppt()
for pn, (txt, img) in enumerate(ppt_parser(
filename if not binary else binary, from_page, to_page, callback)):
d = copy.deepcopy(doc)
pn += from_page
d["image"] = img
d["doc_type_kwd"] = "image"
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
tokenize(d, txt, is_english)
res.append(d)
return res
elif re.search(r"\.(da|wave|wav|mp3|aac|flac|ogg|aiff|au|midi|wma|realaudio|vqf|oggvorbis|ape?)$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
from app.core.rag.app.audio import chunk as parser
return parser(filename, binary, lang=lang, callback=callback, seq2txt_mdl=vision_model, **kwargs)
elif re.search(r"\.(png|jpeg|jpg|gif|bmp|svg|mp4|mov|avi|flv|mpeg|mpg|webm|wmv|3gp|3gpp|mkv?)$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
from app.core.rag.app.picture import chunk as parser
return parser(filename, binary, lang=lang, callback=callback, vision_model=vision_model, **kwargs)
elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
if not binary:
with open(filename, "rb") as f:
binary = f.read()
excel_parser = ExcelParser()
if parser_config.get("html4excel"):
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
else:
sections = [(_, "") for _ in excel_parser(binary) if _]
parser_config["chunk_token_num"] = 12800
elif re.search(r"\.(txt|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|sql)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = TxtParser()(filename, binary,
parser_config.get("chunk_token_num", 128),
parser_config.get("delimiter", "\n!?;。;!?"))
callback(0.8, "Finish parsing.")
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
markdown_parser = Markdown(int(parser_config.get("chunk_token_num", 128)))
sections, tables = markdown_parser(filename, binary, separate_tables=False,delimiter=parser_config.get("delimiter", "\n!?;。;!?"))
if vision_model:
# Process images for each section
section_images = []
for idx, (section_text, _) in enumerate(sections):
images = markdown_parser.get_pictures(section_text) if section_text else None
if images:
# If multiple images found, combine them using concat_img
combined_image = reduce(concat_img, images) if len(images) > 1 else images[0]
section_images.append(combined_image)
markdown_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data= [((combined_image, ["markdown image"]), [(0, 0, 0, 0, 0)])], **kwargs)
boosted_figures = markdown_vision_parser(callback=callback)
sections[idx] = (section_text + "\n\n" + "\n\n".join([fig[0][1][0] for fig in boosted_figures]), sections[idx][1])
else:
section_images.append(None)
else:
logging.warning("No visual model detected. Skipping figure parsing enhancement.")
if parser_config.get("hyperlink_urls", False) and is_root:
for idx, (section_text, _) in enumerate(sections):
soup = markdown_parser.md_to_html(section_text)
hyperlink_urls = markdown_parser.get_hyperlink_urls(soup)
urls.update(hyperlink_urls)
res = tokenize_table(tables, doc, is_english)
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
sections = HtmlParser()(filename, binary, chunk_token_num)
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
elif re.search(r"\.(json|jsonl|ldjson)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
chunk_token_num = int(parser_config.get("chunk_token_num", 128))
sections = JsonParser(chunk_token_num)(filename)
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
try:
import tika
os.environ['TIKA_SERVER_JAR'] = "/tmp/tika-server.jar"
os.environ['TIKA_SERVER_PORT'] = '9998'
# java11 Initialize Tika 3.1.0.jar service urlhttp://localhost:9998 view processlsof -i :9998
tika.initVM()
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []
doc_parsed = tika_parser.from_file(filename)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [(_, "") for _ in sections if _]
callback(0.8, "Finish parsing.")
else:
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
else:
raise NotImplementedError(
"file type not supported yet(pdf, xlsx, doc, docx, txt supported)")
st = timer()
if section_images:
# if all images are None, set section_images to None
if all(image is None for image in section_images):
section_images = None
if section_images:
chunks, images = naive_merge_with_images(sections, section_images,
int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
return chunks
res.extend(tokenize_chunks_with_images(chunks, doc, is_english, images))
else:
chunks = naive_merge(
sections, int(parser_config.get(
"chunk_token_num", 128)), parser_config.get(
"delimiter", "\n!?。;!?"))
if kwargs.get("section_only", False):
chunks.extend(embed_res)
return chunks
res.extend(tokenize_chunks(chunks, doc, is_english, pdf_parser))
if urls and parser_config.get("analyze_hyperlink", False) and is_root:
for index, url in enumerate(urls):
html_bytes, metadata = extract_html(url)
if not html_bytes:
continue
try:
sub_url_res = chunk(url, html_bytes, callback=callback, lang=lang, is_root=False, **kwargs)
except Exception as e:
logging.info(f"Failed to chunk url in registered file type {url}: {e}")
sub_url_res = chunk(f"{index}.html", html_bytes, lang=lang, callback=callback, vision_model=vision_model, is_root=False, **kwargs)
url_res.extend(sub_url_res)
logging.info("naive_merge({}): {}".format(filename, timer() - st))
if embed_res:
res.extend(embed_res)
if url_res:
res.extend(url_res)
return res
if __name__ == "__main__":
# import sys
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
# Prepare to configure vision_model information
vision_model = QWenCV(
key="sk-8e9e40cd171749858ce2d3722ea75669",
model_name="qwen-vl-max",
lang="chinese", # 默认使用中文
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
def progress_callback(prog=None, msg=None):
print(f"prog: {prog} msg: {msg}\n")
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/1.txt"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/2.md"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/3.md" # 带图url
file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/义务教育教科书·中国历史七年级上册 (2)_Compressed.md"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/4.doc"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/5.json"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/6.html"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/7.xlsx"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/8.pdf"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/9.pptx"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/11.mp4"
# file_path = "/Users/sbtjfdn/Downloads/记忆科学/files/12.mp3"
res = chunk(filename=file_path,
from_page=0,
to_page=10,
callback=progress_callback,
vision_model=vision_model,
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"analyze_hyperlink": True,
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
},
is_root=False)
for index, item in enumerate(res):
print(f"Index: {index}\n----")
print(item)
print("----")

149
api/app/core/rag/app/one.py Normal file
View File

@@ -0,0 +1,149 @@
import logging
from io import BytesIO
import re
from app.core.rag.deepdoc.parser.utils import get_text
from . import naive
from app.core.rag.nlp import rag_tokenizer, tokenize
from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, HtmlParser
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_docx_wrapper
from app.core.rag.app.naive import by_plaintext, PARSERS
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin, drop=False)
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug("layouts cost: {}s".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
tbls = self._extract_table_figure(True, zoomin, True, True)
self._concat_downward()
sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
x[-1][0][0], x[-1][0][3], x[-1][0][1]))], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, excel, txt.
One file forms a chunk which maintains original text order.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
eng = lang.lower() == "english" # is_english(cks)
if re.search(r"\.docx$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections, tbls = naive.Docx()(filename, binary)
tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
sections = [s for s, _ in sections if s]
for (_, html), _ in tbls:
sections.append(html)
callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, tbls, pdf_parser = parser(
filename = filename,
binary = binary,
from_page = from_page,
to_page = to_page,
lang = lang,
callback = callback,
pdf_cls = Pdf,
**kwargs
)
if not sections and not tbls:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for (img, rows), poss in tbls:
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
sections = [s for s, _ in sections if s]
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = ExcelParser()
sections = excel_parser.html(binary, 1000000000)
elif re.search(r"\.(txt|md|markdown)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
tokenize(doc, "\n".join(sections), eng)
return [doc]
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

@@ -0,0 +1,284 @@
import logging
import copy
import re
from app.core.rag.deepdoc.parser.figure_parser import vision_figure_parser_pdf_wrapper
from app.core.rag.common.constants import ParserType
from app.core.rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
import numpy as np
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.PAPER.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
start = timer()
self._layouts_rec(zoomin)
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
logging.debug(f"layouts cost: {timer() - start}s")
start = timer()
self._table_transformer_job(zoomin)
callback(0.68, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
tbls = self._extract_table_figure(True, zoomin, True, True)
column_width = np.median([b["x1"] - b["x0"] for b in self.boxes])
self._concat_downward()
self._filter_forpages()
callback(0.75, "Text merged ({:.2f}s)".format(timer() - start))
# clean mess
if column_width < self.page_images[0].size[0] / zoomin / 2:
logging.debug("two_column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
def _begin(txt):
return re.match(
"[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
txt.lower().strip())
if from_page > 0:
return {
"title": "",
"authors": "",
"abstract": "",
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
# get title and authors
title = ""
authors = []
i = 0
while i < min(32, len(self.boxes)-1):
b = self.boxes[i]
i += 1
if b.get("layoutno", "").find("title") >= 0:
title = b["text"]
if _begin(title):
title = ""
break
for j in range(3):
if _begin(self.boxes[i + j]["text"]):
break
authors.append(self.boxes[i + j]["text"])
break
break
# get abstract
abstr = ""
i = 0
while i + 1 < min(32, len(self.boxes)):
b = self.boxes[i]
i += 1
txt = b["text"].lower().strip()
if re.match("(abstract|摘要)", txt):
if len(txt.split()) > 32 or len(txt) > 64:
abstr = txt + self._line_tag(b, zoomin)
break
txt = self.boxes[i]["text"].lower().strip()
if len(txt.split()) > 32 or len(txt) > 64:
abstr = txt + self._line_tag(self.boxes[i], zoomin)
i += 1
break
if not abstr:
i = 0
callback(
0.8, "Page {}~{}: Text merging finished".format(
from_page, min(
to_page, self.total_page)))
for b in self.boxes:
logging.debug("{} {}".format(b["text"], b.get("layoutno")))
logging.debug("{}".format(tbls))
return {
"title": title,
"authors": " ".join(authors),
"abstract": abstr,
"sections": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
re.match(r"(text|title)", b.get("layoutno", "text"))],
"tables": tbls
}
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, **kwargs):
"""
Only pdf is supported.
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
if re.search(r"\.pdf$", filename, re.IGNORECASE):
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
paper = {
"title": filename,
"authors": " ",
"abstract": "",
"sections": pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page)[0],
"tables": []
}
else:
pdf_parser = Pdf()
paper = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
tbls=paper["tables"]
tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
paper["tables"] = tbls
else:
raise NotImplementedError("file type not supported yet(pdf supported)")
doc = {"docnm_kwd": filename, "authors_tks": rag_tokenizer.tokenize(paper["authors"]),
"title_tks": rag_tokenizer.tokenize(paper["title"] if paper["title"] else filename)}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"])
# is it English
eng = lang.lower() == "english" # pdf_parser.is_english
logging.debug("It's English.....{}".format(eng))
res = tokenize_table(paper["tables"], doc, eng)
if paper["abstract"]:
d = copy.deepcopy(doc)
txt = pdf_parser.remove_tag(paper["abstract"])
d["important_kwd"] = ["abstract", "总结", "概括", "summary", "summarize"]
d["important_tks"] = " ".join(d["important_kwd"])
d["image"], poss = pdf_parser.crop(
paper["abstract"], need_position=True)
add_positions(d, poss)
tokenize(d, txt, eng)
res.append(d)
sorted_sections = paper["sections"]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
bull = bullets_category([txt for txt, _ in sorted_sections])
most_level, levels = title_frequency(bull, sorted_sections)
assert len(sorted_sections) == len(levels)
sec_ids = []
sid = 0
for i, lvl in enumerate(levels):
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1
sec_ids.append(sid)
logging.debug("{} {} {} {}".format(lvl, sorted_sections[i][0], most_level, sid))
chunks = []
last_sid = -2
for (txt, _), sec_id in zip(sorted_sections, sec_ids):
if sec_id == last_sid:
if chunks:
chunks[-1] += "\n" + txt
continue
chunks.append(txt)
last_sid = sec_id
res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
return res
"""
readed = [0] * len(paper["lines"])
# find colon firstly
i = 0
while i + 1 < len(paper["lines"]):
txt = pdf_parser.remove_tag(paper["lines"][i][0])
j = i
if txt.strip("\n").strip()[-1] not in ":":
i += 1
continue
i += 1
while i < len(paper["lines"]) and not paper["lines"][i][0]:
i += 1
if i >= len(paper["lines"]): break
proj = [paper["lines"][i][0].strip()]
i += 1
while i < len(paper["lines"]) and paper["lines"][i][0].strip()[0] == proj[-1][0]:
proj.append(paper["lines"][i])
i += 1
for k in range(j, i): readed[k] = True
txt = txt[::-1]
if eng:
r = re.search(r"(.*?) ([\\.;?!]|$)", txt)
txt = r.group(1)[::-1] if r else txt[::-1]
else:
r = re.search(r"(.*?) ([。?;!]|$)", txt)
txt = r.group(1)[::-1] if r else txt[::-1]
for p in proj:
d = copy.deepcopy(doc)
txt += "\n" + pdf_parser.remove_tag(p)
d["image"], poss = pdf_parser.crop(p, need_position=True)
add_positions(d, poss)
tokenize(d, txt, eng)
res.append(d)
i = 0
chunk = []
tk_cnt = 0
def add_chunk():
nonlocal chunk, res, doc, pdf_parser, tk_cnt
d = copy.deepcopy(doc)
ck = "\n".join(chunk)
tokenize(d, pdf_parser.remove_tag(ck), pdf_parser.is_english)
d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss)
res.append(d)
chunk = []
tk_cnt = 0
while i < len(paper["lines"]):
if tk_cnt > 128:
add_chunk()
if readed[i]:
i += 1
continue
readed[i] = True
txt, layouts = paper["lines"][i]
txt_ = pdf_parser.remove_tag(txt)
i += 1
cnt = num_tokens_from_string(txt_)
if any([
layouts.find("title") >= 0 and chunk,
cnt + tk_cnt > 128 and tk_cnt > 32,
]):
add_chunk()
chunk = [txt]
tk_cnt = cnt
else:
chunk.append(txt)
tk_cnt += cnt
if chunk: add_chunk()
for i, d in enumerate(res):
print(d)
# d["image"].save(f"./logs/{i}.jpg")
return res
"""
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], callback=dummy)

View File

@@ -0,0 +1,96 @@
import io
import re
import numpy as np
from PIL import Image
from app.core.rag.deepdoc.vision import OCR
from app.core.rag.nlp import rag_tokenizer, tokenize
from app.core.rag.common.string_utils import clean_markdown_block
ocr = OCR()
# Gemini supported MIME types
VIDEO_EXTS = [".mp4", ".mov", ".avi", ".flv", ".mpeg", ".mpg", ".webm", ".wmv", ".3gp", ".3gpp", ".mkv"]
def chunk(filename, binary, lang, callback=None, vision_model=None, **kwargs):
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename)),
}
eng = lang.lower() == "english"
if any(filename.lower().endswith(ext) for ext in VIDEO_EXTS):
try:
doc.update({"doc_type_kwd": "video"})
ans = vision_model.chat(system="", history=[], gen_conf={}, video_bytes=binary, filename=filename)
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
ans += "\n" + ans
tokenize(doc, ans, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
else:
img = Image.open(io.BytesIO(binary)).convert("RGB")
doc.update(
{
"image": img,
"doc_type_kwd": "image",
}
)
bxs = ocr(np.array(img))
txt = "\n".join([t[0] for _, t in bxs if t[0]])
callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
if (eng and len(txt.split()) > 32) or len(txt) > 32:
tokenize(doc, txt, eng)
callback(0.8, "OCR results is too long to use CV LLM.")
return [doc]
try:
callback(0.4, "Use CV LLM to describe the picture.")
img_binary = io.BytesIO()
img.save(img_binary, format="JPEG")
img_binary.seek(0)
ans = vision_model.describe(img_binary.read())
callback(0.8, "CV LLM respond: %s ..." % ans[:32])
txt += "\n" + ans
tokenize(doc, txt, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
return []
def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
"""
A simple wrapper to process image to markdown texts via VLM.
Returns:
Simple markdown texts generated by VLM.
"""
callback = callback or (lambda prog, msg: None)
img = binary
txt = ""
try:
with io.BytesIO() as img_binary:
try:
img.save(img_binary, format="JPEG")
except Exception:
img_binary.seek(0)
img_binary.truncate()
img.save(img_binary, format="PNG")
img_binary.seek(0)
description, token_count = vision_model.describe_with_prompt(img_binary.read(), prompt)
ans = clean_markdown_block(description)
txt += "\n" + ans
return txt
except Exception as e:
callback(-1, str(e))
return ""

View File

@@ -0,0 +1,164 @@
import copy
import re
from io import BytesIO
from PIL import Image
from app.core.rag.nlp import tokenize, is_english
from app.core.rag.nlp import rag_tokenizer
from app.core.rag.deepdoc.parser import PdfParser, PptParser, PlainParser
from PyPDF2 import PdfReader as pdf2_read
from app.core.rag.app.naive import by_plaintext, PARSERS
class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
txts = super().__call__(fnm, from_page, to_page)
callback(0.5, "Text extraction finished.")
import aspose.slides as slides
import aspose.pydrawing as drawing
imgs = []
with slides.Presentation(BytesIO(fnm)) as presentation:
for i, slide in enumerate(presentation.slides[from_page: to_page]):
try:
with BytesIO() as buffered:
slide.get_thumbnail(
0.1, 0.1).save(
buffered, drawing.imaging.ImageFormat.jpeg)
buffered.seek(0)
imgs.append(Image.open(buffered).copy())
except RuntimeError as e:
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
assert len(imgs) == len(
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
callback(0.9, "Image extraction finished")
self.is_english = is_english(txts)
return [(txts[i], imgs[i]) for i in range(len(txts))]
class Pdf(PdfParser):
def __init__(self):
super().__init__()
def __garbage(self, txt):
txt = txt.lower().strip()
if re.match(r"[0-9\.,%/-]+$", txt):
return True
if len(txt) < 3:
return True
return False
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
from timeit import default_timer as timer
start = timer()
callback(msg="OCR started")
self.__images__(filename if not binary else binary,
zoomin, from_page, to_page, callback)
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
len(self.boxes), len(self.page_images))
res = []
for i in range(len(self.boxes)):
lines = "\n".join([b["text"] for b in self.boxes[i]
if not self.__garbage(b["text"])])
res.append((lines, self.page_images[i]))
callback(0.9, "Page {}~{}: Parsing finished".format(
from_page, min(to_page, self.total_page)))
return res, []
class PlainPdf(PlainParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, callback=None, **kwargs):
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
page_txt = []
for page in self.pdf.pages[from_page: to_page]:
page_txt.append(page.extract_text())
callback(0.9, "Parsing finished")
return [(txt, None) for txt in page_txt], []
def chunk(filename, binary=None, from_page=0, to_page=100000,
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
"""
The supported file formats are pdf, pptx.
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
"""
if parser_config is None:
parser_config = {}
eng = lang.lower() == "english"
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
res = []
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
if not binary:
with open(filename, "rb") as f:
binary = f.read()
ppt_parser = Ppt()
for pn, (txt, img) in enumerate(ppt_parser(
filename if not binary else binary, from_page, 1000000, callback)):
d = copy.deepcopy(doc)
pn += from_page
d["image"] = img
d["doc_type_kwd"] = "image"
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
tokenize(d, txt, eng)
res.append(d)
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if isinstance(layout_recognizer, bool):
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
name = layout_recognizer.strip().lower()
parser = PARSERS.get(name, by_plaintext)
callback(0.1, "Start to parse.")
sections, _, _ = parser(
filename=filename,
binary=binary,
from_page=from_page,
to_page=to_page,
lang=lang,
callback=callback,
vision_model=vision_model,
pdf_cls=Pdf,
**kwargs
)
if not sections:
return []
if name in ["tcadp", "docling", "mineru"]:
parser_config["chunk_token_num"] = 0
callback(0.8, "Finish parsing.")
for pn, (txt, img) in enumerate(sections):
d = copy.deepcopy(doc)
pn += from_page
if img:
d["image"] = img
d["page_num_int"] = [pn + 1]
d["top_int"] = [0]
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
tokenize(d, txt, eng)
res.append(d)
return res
raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)

455
api/app/core/rag/app/qa.py Normal file
View File

@@ -0,0 +1,455 @@
import logging
import re
import csv
from copy import deepcopy
from io import BytesIO
from timeit import default_timer as timer
from openpyxl import load_workbook
from app.core.rag.deepdoc.parser.utils import get_text
from app.core.rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level
from app.core.rag.nlp import rag_tokenizer, tokenize_table, concat_img
from app.core.rag.deepdoc.parser import PdfParser, ExcelParser, DocxParser
from docx import Document
from PIL import Image
from markdown import markdown
from app.core.rag.common.float_utils import get_float
class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
total += len(list(wb[sheetname].rows))
res, fails = [], []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
for i, r in enumerate(rows):
q, a = "", ""
for cell in r:
if not cell.value:
continue
if not q:
q = str(cell.value)
elif not a:
a = str(cell.value)
else:
break
if q and a:
res.append((q, a))
else:
fails.append(str(i + 1))
if len(res) % 999 == 0:
callback(len(res) *
0.6 /
total, ("Extract pairs: {}".format(len(res)) +
(f"{len(fails)} failure, line: %s..." %
(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english(
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
return res
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
start = timer()
callback(msg="OCR started")
self.__images__(
filename if not binary else binary,
zoomin,
from_page,
to_page,
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
logging.debug("OCR({}~{}): {:.2f}s".format(from_page, to_page, timer() - start))
start = timer()
self._layouts_rec(zoomin, drop=False)
callback(0.63, "Layout analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._table_transformer_job(zoomin)
callback(0.65, "Table analysis ({:.2f}s)".format(timer() - start))
start = timer()
self._text_merge()
callback(0.67, "Text merged ({:.2f}s)".format(timer() - start))
tbls = self._extract_table_figure(True, zoomin, True, True)
#self._naive_vertical_merge()
# self._concat_downward()
#self._filter_forpages()
logging.debug("layouts: {}".format(timer() - start))
sections = [b["text"] for b in self.boxes]
bull_x0_list = []
q_bull, reg = qbullets_category(sections)
if q_bull == -1:
raise ValueError("Unable to recognize Q&A structure.")
qai_list = []
last_q, last_a, last_tag = '', '', ''
last_index = -1
last_box = {'text':''}
last_bull = None
def sort_key(element):
tbls_pn = element[1][0][0]
tbls_top = element[1][0][3]
return tbls_pn, tbls_top
tbls.sort(key=sort_key)
tbl_index = 0
last_pn, last_bottom = 0, 0
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
for box in self.boxes:
section, line_tag = box['text'], self._line_tag(box, zoomin)
has_bull, index = has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list)
last_box, last_index, last_bull = box, index, has_bull
line_pn = get_float(line_tag.lstrip('@@').split('\t')[0])
line_top = get_float(line_tag.rstrip('##').split('\t')[3])
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
if not has_bull: # No question bullet
if not last_q:
if tbl_pn < line_pn or (tbl_pn == line_pn and tbl_top <= line_top): # image passed
tbl_index += 1
continue
else:
sum_tag = line_tag
sum_section = section
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the middle of current answer
sum_tag = f'{tbl_tag}{sum_tag}'
sum_section = f'{tbl_text}{sum_section}'
tbl_index += 1
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
last_a = f'{last_a}{sum_section}'
last_tag = f'{last_tag}{sum_tag}'
else:
if last_q:
while ((tbl_pn == last_pn and tbl_top>= last_bottom) or (tbl_pn > last_pn)) \
and ((tbl_pn == line_pn and tbl_top <= line_top) or (tbl_pn < line_pn)): # add image at the end of last answer
last_tag = f'{last_tag}{tbl_tag}'
last_a = f'{last_a}{tbl_text}'
tbl_index += 1
tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, tbl_text = self.get_tbls_info(tbls, tbl_index)
image, poss = self.crop(last_tag, need_position=True)
qai_list.append((last_q, last_a, image, poss))
last_q, last_a, last_tag = '', '', ''
last_q = has_bull.group()
_, end = has_bull.span()
last_a = section[end:]
last_tag = line_tag
last_bottom = float(line_tag.rstrip('##').split('\t')[4])
last_pn = line_pn
if last_q:
qai_list.append((last_q, last_a, *self.crop(last_tag, need_position=True)))
return qai_list, tbls
def get_tbls_info(self, tbls, tbl_index):
if tbl_index >= len(tbls):
return 1, 0, 0, 0, 0, '@@0\t0\t0\t0\t0##', ''
tbl_pn = tbls[tbl_index][1][0][0]+1
tbl_left = tbls[tbl_index][1][0][1]
tbl_right = tbls[tbl_index][1][0][2]
tbl_top = tbls[tbl_index][1][0][3]
tbl_bottom = tbls[tbl_index][1][0][4]
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
_tbl_text = ''.join(tbls[tbl_index][0][1])
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag, _tbl_text
class Docx(DocxParser):
def __init__(self):
pass
def get_picture(self, document, paragraph):
img = paragraph._element.xpath('.//pic:pic')
if not img:
return None
img = img[0]
embed = img.xpath('.//a:blip/@r:embed')[0]
related_part = document.part.related_parts[embed]
image = related_part.image
image = Image.open(BytesIO(image.blob)).convert('RGB')
return image
def __call__(self, filename, binary=None, from_page=0, to_page=100000, callback=None):
self.doc = Document(
filename) if not binary else Document(BytesIO(binary))
pn = 0
last_answer, last_image = "", None
question_stack, level_stack = [], []
qai_list = []
for p in self.doc.paragraphs:
if pn > to_page:
break
question_level, p_text = 0, ''
if from_page <= pn < to_page and p.text.strip():
question_level, p_text = docx_question_level(p)
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{p_text}'
current_image = self.get_picture(self.doc, p)
last_image = concat_img(last_image, current_image)
else: # is a question
if last_answer or last_image:
sum_question = '\n'.join(question_stack)
if sum_question:
qai_list.append((sum_question, last_answer, last_image))
last_answer, last_image = '', None
i = question_level
while question_stack and i <= level_stack[-1]:
question_stack.pop()
level_stack.pop()
question_stack.append(p_text)
level_stack.append(question_level)
for run in p.runs:
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
if last_answer:
sum_question = '\n'.join(question_stack)
if sum_question:
qai_list.append((sum_question, last_answer, last_image))
tbls = []
for tb in self.doc.tables:
html= "<table>"
for r in tb.rows:
html += "<tr>"
i = 0
while i < len(r.cells):
span = 1
c = r.cells[i]
for j in range(i+1, len(r.cells)):
if c.text == r.cells[j].text:
span += 1
i = j
i += 1
html += f"<td>{c.text}</td>" if span == 1 else f"<td colspan='{span}'>{c.text}</td>"
html += "</tr>"
html += "</table>"
tbls.append(((None, html), ""))
return qai_list, tbls
def rmPrefix(txt):
return re.sub(
r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t: ]+", "", txt.strip(), flags=re.IGNORECASE)
def beAdocPdf(d, q, a, eng, image, poss):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if image:
d["image"] = image
d["doc_type_kwd"] = "image"
add_positions(d, poss)
return d
def beAdocDocx(d, q, a, eng, image, row_num=-1):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if image:
d["image"] = image
d["doc_type_kwd"] = "image"
if row_num >= 0:
d["top_int"] = [row_num]
return d
def beAdoc(d, q, a, eng, row_num=-1):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
d["content_ltks"] = rag_tokenizer.tokenize(q)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
if row_num >= 0:
d["top_int"] = [row_num]
return d
def mdQuestionLevel(s):
match = re.match(r'#*', s)
return (len(match.group(0)), s.lstrip('#').lstrip()) if match else (0, s)
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column question and answer without header.
And question column is ahead of answer column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate question and answer.
All the deformed lines will be ignored.
Every pair of Q&A will be treated as a chunk.
"""
eng = lang.lower() == "english"
res = []
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
}
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
return res
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","
fails = []
question, answer = "", ""
i = 0
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i+1))
elif len(arr) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
question, answer = arr
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
delimiter = "\t" if any("\t" in line for line in lines) else ","
fails = []
question, answer = "", ""
res = []
reader = csv.reader(lines, delimiter=delimiter)
for i, row in enumerate(reader):
if len(row) != 2:
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i + 1))
elif len(row) == 2:
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
question, answer = row
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(list(reader))))
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
pdf_parser = Pdf()
qai_list, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
for q, a, image, poss in qai_list:
res.append(beAdocPdf(deepcopy(doc), q, a, eng, image, poss))
return res
elif re.search(r"\.(md|markdown)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
_last_question, last_answer = "", ""
question_stack, level_stack = [], []
code_block = False
for index, line in enumerate(lines):
if line.strip().startswith('```'):
code_block = not code_block
question_level, question = 0, ''
if not code_block:
question_level, question = mdQuestionLevel(line)
if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{line}'
else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
last_answer = ''
i = question_level
while question_stack and i <= level_stack[-1]:
question_stack.pop()
level_stack.pop()
question_stack.append(question)
level_stack.append(question_level)
if last_answer.strip():
sum_question = '\n'.join(question_stack)
if sum_question:
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
return res
elif re.search(r"\.docx$", filename, re.IGNORECASE):
docx_parser = Docx()
qai_list, tbls = docx_parser(filename, binary,
from_page=0, to_page=10000, callback=callback)
res = tokenize_table(tbls, doc, eng)
for i, (q, a, image) in enumerate(qai_list):
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
return res
raise NotImplementedError(
"Excel, csv(txt), pdf, markdown and docx format files are supported.")
if __name__ == "__main__":
import sys
def dummy(prog=None, msg=""):
pass
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)

View File

View File

@@ -0,0 +1,106 @@
import os
import queue
import threading
from typing import Any, Callable, Coroutine, Optional, Type, Union
import asyncio
import trio
from functools import wraps
from flask import make_response, jsonify
from .constants import RetCode
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]
def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None,
on_timeout: Optional[OnTimeoutCallback] = None):
if isinstance(seconds, str):
seconds = float(seconds)
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
result_queue = queue.Queue(maxsize=1)
def target():
try:
result = func(*args, **kwargs)
result_queue.put(result)
except Exception as e:
result_queue.put(e)
thread = threading.Thread(target=target)
thread.daemon = True
thread.start()
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:
if seconds is None:
return await func(*args, **kwargs)
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
with trio.fail_after(seconds):
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
except trio.TooSlowError:
if a < attempts - 1:
continue
if on_timeout is not None:
if callable(on_timeout):
result = on_timeout()
if isinstance(result, Coroutine):
return await result
return result
return on_timeout
if exception is None:
raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
if isinstance(exception, BaseException):
raise exception
if isinstance(exception, type) and issubclass(exception, BaseException):
raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
raise RuntimeError("Invalid exception type provided")
if asyncio.iscoroutinefunction(func):
return async_wrapper
return wrapper
return decorator
def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
for key, value in result_dict.items():
if value is None and key != "code":
continue
else:
response_dict[key] = value
response = make_response(jsonify(response_dict))
if auth:
response.headers["Authorization"] = auth
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Method"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
response.headers["Access-Control-Expose-Headers"] = "Authorization"
return response

View File

@@ -0,0 +1,180 @@
from enum import Enum, IntEnum
from strenum import StrEnum
SERVICE_CONF = "service_conf.yaml"
RAG_SERVICE_NAME = "rag"
class CustomEnum(Enum):
@classmethod
def valid(cls, value):
try:
cls(value)
return True
except BaseException:
return False
@classmethod
def values(cls):
return [member.value for member in cls.__members__.values()]
@classmethod
def names(cls):
return [member.name for member in cls.__members__.values()]
class RetCode(IntEnum, CustomEnum):
SUCCESS = 0
NOT_EFFECTIVE = 10
EXCEPTION_ERROR = 100
ARGUMENT_ERROR = 101
DATA_ERROR = 102
OPERATING_ERROR = 103
CONNECTION_ERROR = 105
RUNNING = 106
PERMISSION_ERROR = 108
AUTHENTICATION_ERROR = 109
UNAUTHORIZED = 401
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
class StatusEnum(Enum):
VALID = "1"
INVALID = "0"
class ActiveEnum(Enum):
ACTIVE = "1"
INACTIVE = "0"
class LLMType(StrEnum):
CHAT = 'chat'
EMBEDDING = 'embedding'
SPEECH2TEXT = 'speech2text'
IMAGE2TEXT = 'image2text'
RERANK = 'rerank'
TTS = 'tts'
class TaskStatus(StrEnum):
UNSTART = "0"
RUNNING = "1"
CANCEL = "2"
DONE = "3"
FAIL = "4"
SCHEDULE = "5"
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL,
TaskStatus.SCHEDULE}
class ParserType(StrEnum):
PRESENTATION = "presentation"
LAWS = "laws"
MANUAL = "manual"
PAPER = "paper"
RESUME = "resume"
BOOK = "book"
QA = "qa"
TABLE = "table"
NAIVE = "naive"
PICTURE = "picture"
ONE = "one"
AUDIO = "audio"
EMAIL = "email"
KG = "knowledge_graph"
TAG = "tag"
class FileSource(StrEnum):
LOCAL = ""
KNOWLEDGEBASE = "knowledgebase"
S3 = "s3"
NOTION = "notion"
DISCORD = "discord"
CONFLUENCE = "confluence"
GMAIL = "gmail"
GOOGLE_DRIVE = "google_drive"
JIRA = "jira"
SHAREPOINT = "sharepoint"
SLACK = "slack"
TEAMS = "teams"
class PipelineTaskType(StrEnum):
PARSE = "Parse"
DOWNLOAD = "Download"
RAPTOR = "RAPTOR"
GRAPH_RAG = "GraphRAG"
MINDMAP = "Mindmap"
VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR,
PipelineTaskType.GRAPH_RAG, PipelineTaskType.MINDMAP}
class MCPServerType(StrEnum):
SSE = "sse"
STREAMABLE_HTTP = "streamable-http"
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
class Storage(Enum):
MINIO = 1
AZURE_SPN = 2
AZURE_SAS = 3
AWS_S3 = 4
OSS = 5
OPENDAL = 6
# environment
# ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT"
# ENV_RAG_SECRET_KEY = "RAG_SECRET_KEY"
# ENV_REGISTER_ENABLED = "REGISTER_ENABLED"
# ENV_DOC_ENGINE = "DOC_ENGINE"
# ENV_SANDBOX_ENABLED = "SANDBOX_ENABLED"
# ENV_SANDBOX_HOST = "SANDBOX_HOST"
# ENV_MAX_CONTENT_LENGTH = "MAX_CONTENT_LENGTH"
# ENV_COMPONENT_EXEC_TIMEOUT = "COMPONENT_EXEC_TIMEOUT"
# ENV_TRINO_USE_TLS = "TRINO_USE_TLS"
# ENV_MAX_FILE_NUM_PER_USER = "MAX_FILE_NUM_PER_USER"
# ENV_MACOS = "MACOS"
# ENV_RAG_DEBUGPY_LISTEN = "RAG_DEBUGPY_LISTEN"
# ENV_WERKZEUG_RUN_MAIN = "WERKZEUG_RUN_MAIN"
# ENV_DISABLE_SDK = "DISABLE_SDK"
# ENV_ENABLE_TIMEOUT_ASSERTION = "ENABLE_TIMEOUT_ASSERTION"
# ENV_LOG_LEVELS = "LOG_LEVELS"
# ENV_TENSORRT_DLA_SVR = "TENSORRT_DLA_SVR"
# ENV_OCR_GPU_MEM_LIMIT_MB = "OCR_GPU_MEM_LIMIT_MB"
# ENV_OCR_ARENA_EXTEND_STRATEGY = "OCR_ARENA_EXTEND_STRATEGY"
# ENV_MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK = "MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK"
# ENV_MAX_MAX_CONCURRENT_CHATS = "MAX_CONCURRENT_CHATS"
# ENV_RAG_MCP_BASE_URL = "RAG_MCP_BASE_URL"
# ENV_RAG_MCP_HOST = "RAG_MCP_HOST"
# ENV_RAG_MCP_PORT = "RAG_MCP_PORT"
# ENV_RAG_MCP_LAUNCH_MODE = "RAG_MCP_LAUNCH_MODE"
# ENV_RAG_MCP_HOST_API_KEY = "RAG_MCP_HOST_API_KEY"
# ENV_MINERU_EXECUTABLE = "MINERU_EXECUTABLE"
# ENV_MINERU_APISERVER = "MINERU_APISERVER"
# ENV_MINERU_OUTPUT_DIR = "MINERU_OUTPUT_DIR"
# ENV_MINERU_BACKEND = "MINERU_BACKEND"
# ENV_MINERU_DELETE_OUTPUT = "MINERU_DELETE_OUTPUT"
# ENV_TCADP_OUTPUT_DIR = "TCADP_OUTPUT_DIR"
# ENV_LM_TIMEOUT_SECONDS = "LM_TIMEOUT_SECONDS"
# ENV_LLM_MAX_RETRIES = "LLM_MAX_RETRIES"
# ENV_LLM_BASE_DELAY = "LLM_BASE_DELAY"
# ENV_OLLAMA_KEEP_ALIVE = "OLLAMA_KEEP_ALIVE"
# ENV_DOC_BULK_SIZE = "DOC_BULK_SIZE"
# ENV_EMBEDDING_BATCH_SIZE = "EMBEDDING_BATCH_SIZE"
# ENV_MAX_CONCURRENT_TASKS = "MAX_CONCURRENT_TASKS"
# ENV_MAX_CONCURRENT_CHUNK_BUILDERS = "MAX_CONCURRENT_CHUNK_BUILDERS"
# ENV_MAX_CONCURRENT_MINIO = "MAX_CONCURRENT_MINIO"
# ENV_WORKER_HEARTBEAT_TIMEOUT = "WORKER_HEARTBEAT_TIMEOUT"
# ENV_TRACE_MALLOC_ENABLED = "TRACE_MALLOC_ENABLED"
PAGERANK_FLD = "pagerank_fea"
SVR_QUEUE_NAME = "rag_svr_queue"
SVR_CONSUMER_GROUP_NAME = "rag_svr_task_broker"
TAG_FLD = "tag_feas"

View File

@@ -0,0 +1,28 @@
import os
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
def get_project_base_directory(*args):
global PROJECT_BASE
if PROJECT_BASE is None:
PROJECT_BASE = os.path.abspath(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
os.pardir,
os.pardir,
os.pardir,
os.pardir,
)
)
if args:
return os.path.join(PROJECT_BASE, *args)
return PROJECT_BASE
def traversal_files(base):
for root, ds, fs in os.walk(base):
for f in fs:
fullname = os.path.join(root, f)
yield fullname

View File

@@ -0,0 +1,30 @@
def get_float(v):
"""
Convert a value to float, handling None and exceptions gracefully.
Attempts to convert the input value to a float. If the value is None or
cannot be converted to float, returns negative infinity as a default value.
Args:
v: The value to convert to float. Can be any type that float() accepts,
or None.
Returns:
float: The converted float value if successful, otherwise float('-inf').
Examples:
>>> get_float("3.14")
3.14
>>> get_float(None)
-inf
>>> get_float("invalid")
-inf
>>> get_float(42)
42.0
"""
if v is None:
return float('-inf')
try:
return float(v)
except Exception:
return float('-inf')

View File

@@ -0,0 +1,92 @@
import base64
import hashlib
import uuid
import requests
import threading
import subprocess
import sys
import os
import logging
def get_uuid():
return uuid.uuid1().hex
def download_img(url):
if not url:
return ""
response = requests.get(url)
return "data:" + \
response.headers.get('Content-Type', 'image/jpg') + ";" + \
"base64," + base64.b64encode(response.content).decode("utf-8")
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
def convert_bytes(size_in_bytes: int) -> str:
"""
Format size in bytes.
"""
if size_in_bytes == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
i = 0
size = float(size_in_bytes)
while size >= 1024 and i < len(units) - 1:
size /= 1024
i += 1
if i == 0 or size >= 100:
return f"{size:.0f} {units[i]}"
elif size >= 10:
return f"{size:.1f} {units[i]}"
else:
return f"{size:.2f} {units[i]}"
def once(func):
"""
A thread-safe decorator that ensures the decorated function runs exactly once,
caching and returning its result for all subsequent calls. This prevents
race conditions in multi-thread environments by using a lock to protect
the execution state.
Args:
func (callable): The function to be executed only once.
Returns:
callable: A wrapper function that executes `func` on the first call
and returns the cached result thereafter.
Example:
@once
def compute_expensive_value():
print("Computing...")
return 42
# First call: executes and prints
# Subsequent calls: return 42 without executing
"""
executed = False
result = None
lock = threading.Lock()
def wrapper(*args, **kwargs):
nonlocal executed, result
with lock:
if not executed:
executed = True
result = func(*args, **kwargs)
return result
return wrapper
@once
def pip_install_torch():
device = os.getenv("DEVICE", "cpu")
if device=="cpu":
return
logging.info("Installing pytorch")
pkg_names = ["torch>=2.5.0,<3.0.0"]
subprocess.check_call([sys.executable, "-m", "pip", "install", *pkg_names])

View File

@@ -0,0 +1,2 @@
PARALLEL_DEVICES: int = 0

View File

@@ -0,0 +1,57 @@
import re
def remove_redundant_spaces(txt: str):
"""
Remove redundant spaces around punctuation marks while preserving meaningful spaces.
This function performs two main operations:
1. Remove spaces after left-boundary characters (opening brackets, etc.)
2. Remove spaces before right-boundary characters (closing brackets, punctuation, etc.)
Args:
txt (str): Input text to process
Returns:
str: Text with redundant spaces removed
"""
# First pass: Remove spaces after left-boundary characters
# Matches: [non-alphanumeric-and-specific-right-punctuation] + [non-space]
# Removes spaces after characters like '(', '<', and other non-alphanumeric chars
# Examples:
# "( test" → "(test"
txt = re.sub(r"([^a-z0-9.,\)>]) +([^ ])", r"\1\2", txt, flags=re.IGNORECASE)
# Second pass: Remove spaces before right-boundary characters
# Matches: [non-space] + [non-alphanumeric-and-specific-left-punctuation]
# Removes spaces before characters like non-')', non-',', non-'.', and non-alphanumeric chars
# Examples:
# "world !" → "world!"
return re.sub(r"([^ ]) +([^a-z0-9.,\(<])", r"\1\2", txt, flags=re.IGNORECASE)
def clean_markdown_block(text):
"""
Remove Markdown code block syntax from the beginning and end of text.
This function cleans Markdown code blocks by removing:
- Opening ```Markdown tags (with optional whitespace and newlines)
- Closing ``` tags (with optional whitespace and newlines)
Args:
text (str): Input text that may be wrapped in Markdown code blocks
Returns:
str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace
"""
# Remove opening ```markdown tag with optional whitespace and newlines
# Matches: optional whitespace + ```markdown + optional whitespace + optional newline
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
# Remove closing ``` tag with optional whitespace and newlines
# Matches: optional newline + optional whitespace + ``` + optional whitespace at end
text = re.sub(r'\n?\s*```\s*$', '', text)
# Return text with surrounding whitespace removed
return text.strip()

View File

@@ -0,0 +1,59 @@
import os
import tiktoken
from .file_utils import get_project_base_directory
tiktoken_cache_dir = os.path.join(get_project_base_directory(), "res")
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
encoder = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
try:
code_list = encoder.encode(string)
return len(code_list)
except Exception:
return 0
def total_token_count_from_response(resp):
if resp is None:
return 0
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
try:
return resp.usage.total_tokens
except Exception:
pass
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
try:
return resp.usage_metadata.total_tokens
except Exception:
pass
if 'usage' in resp and 'total_tokens' in resp['usage']:
try:
return resp["usage"]["total_tokens"]
except Exception:
pass
if 'usage' in resp and 'input_tokens' in resp['usage'] and 'output_tokens' in resp['usage']:
try:
return resp["usage"]["input_tokens"] + resp["usage"]["output_tokens"]
except Exception:
pass
if 'meta' in resp and 'tokens' in resp['meta'] and 'input_tokens' in resp['meta']['tokens'] and 'output_tokens' in resp['meta']['tokens']:
try:
return resp["meta"]["tokens"]["input_tokens"] + resp["meta"]["tokens"]["output_tokens"]
except Exception:
pass
return 0
def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""
return encoder.decode(encoder.encode(string)[:max_len])

View File

@@ -0,0 +1,122 @@
English | [简体中文](./README_zh.md)
# *Deep*Doc
- [1. Introduction](#1)
- [2. Vision](#2)
- [3. Parser](#3)
<a name="1"></a>
## 1. Introduction
With a bunch of documents from various domains with various formats and along with diverse retrieval requirements,
an accurate analysis becomes a very challenge task. *Deep*Doc is born for that purpose.
There are 2 parts in *Deep*Doc so far: vision and parser.
You can run the flowing test programs if you're interested in our results of OCR, layout recognition and TSR.
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './layouts_outputs'
--threshold THRESHOLD
A threshold to filter out detections. Default: 0.5
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
```
Our models are served on HuggingFace. If you have trouble downloading HuggingFace models, this might help!!
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a>
## 2. Vision
We use vision information to resolve problems as human being.
- OCR. Since a lot of documents presented as images or at least be able to transform to image,
OCR is a very essential and fundamental or even universal solution for text extraction.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the positions of results,
txt files which contain the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
- Layout recognition. Documents from different domain may have various layouts,
like, newspaper, magazine, book and résumé are distinct in terms of layout.
Only when machine have an accurate layout analysis, it can decide if these text parts are successive or not,
or this part needs Table Structure Recognition(TSR) to process, or this part is a figure and described with this caption.
We have 10 basic layout components which covers most cases:
- Text
- Title
- Figure
- Figure caption
- Table
- Table caption
- Header
- Footer
- Reference
- Equation
Have a try on the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has images which demonstrate the detection results as following:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
- Table Structure Recognition(TSR). Data table is a frequently used structure to present data including numbers or text.
And the structure of a table might be very complex, like hierarchy headers, spanning cells and projected row headers.
Along with TSR, we also reassemble the content into sentences which could be well comprehended by LLM.
We have five labels for TSR task:
- Column
- Row
- Column header
- Projected row header
- Spanning cell
Have a try on the following command to see the layout detection results.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
You can look into the folder 'path_to_store_result' where has both images and html pages which demonstrate the detection results as following:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
<a name="3"></a>
## 3. Parser
Four kinds of document formats as PDF, DOCX, EXCEL and PPT have their corresponding parser.
The most complex one is PDF parser since PDF's flexibility. The output of PDF parser includes:
- Text chunks with their own positions in PDF(page number and rectangular positions).
- Tables with cropped image from the PDF, and contents which has already translated into natural language sentences.
- Figures with caption and text in the figures.
### Résumé
The résumé is a very complicated kind of document. A résumé which is composed of unstructured text
with various layouts could be resolved into structured data composed of nearly a hundred of fields.
We haven't opened the parser yet, as we open the processing method after parsing procedure.

View File

@@ -0,0 +1,116 @@
[English](./README.md) | 简体中文
# *Deep*Doc
- [*Deep*Doc](#deepdoc)
- [1. 介绍](#1-介绍)
- [2. 视觉处理](#2-视觉处理)
- [3. 解析器](#3-解析器)
- [简历](#简历)
<a name="1"></a>
## 1. 介绍
对于来自不同领域、具有不同格式和不同检索要求的大量文档,准确的分析成为一项极具挑战性的任务。*Deep*Doc 就是为了这个目的而诞生的。到目前为止,*Deep*Doc 中有两个组成部分视觉处理和解析器。如果您对我们的OCR、布局识别和TSR结果感兴趣您可以运行下面的测试程序。
```bash
python deepdoc/vision/t_ocr.py -h
usage: t_ocr.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './ocr_outputs'
```
```bash
python deepdoc/vision/t_recognizer.py -h
usage: t_recognizer.py [-h] --inputs INPUTS [--output_dir OUTPUT_DIR] [--threshold THRESHOLD] [--mode {layout,tsr}]
options:
-h, --help show this help message and exit
--inputs INPUTS Directory where to store images or PDFs, or a file path to a single image or PDF
--output_dir OUTPUT_DIR
Directory where to store the output images. Default: './layouts_outputs'
--threshold THRESHOLD
A threshold to filter out detections. Default: 0.5
--mode {layout,tsr} Task mode: layout recognition or table structure recognition
```
HuggingFace为我们的模型提供服务。如果你在下载HuggingFace模型时遇到问题这可能会有所帮助
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
<a name="2"></a>
## 2. 视觉处理
作为人类,我们使用视觉信息来解决问题。
- **OCROptical Character Recognition光学字符识别**。由于许多文档都是以图像形式呈现的或者至少能够转换为图像因此OCR是文本提取的一个非常重要、基本甚至通用的解决方案。
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` 其中有演示结果位置的图像以及包含OCR文本的txt文件。
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/f25bee3d-aaf7-4102-baf5-d5208361d110" width="900"/>
</div>
- 布局识别Layout recognition。来自不同领域的文件可能有不同的布局如报纸、杂志、书籍和简历在布局方面是不同的。只有当机器有准确的布局分析时它才能决定这些文本部分是连续的还是不连续的或者这个部分需要表结构识别Table Structure RecognitionTSR来处理或者这个部件是一个图形并用这个标题来描述。我们有10个基本布局组件涵盖了大多数情况
- 文本
- 标题
- 配图
- 配图标题
- 表格
- 表格标题
- 页头
- 页尾
- 参考引用
- 公式
请尝试以下命令以查看布局检测结果。
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` ,其中有显示检测结果的图像,如下所示:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
</div>
- **TSRTable Structure Recognition表结构识别**。数据表是一种常用的结构用于表示包括数字或文本在内的数据。表的结构可能非常复杂比如层次结构标题、跨单元格和投影行标题。除了TSR我们还将内容重新组合成LLM可以很好理解的句子。TSR任务有五个标签
- 列
- 行
- 列标题
- 行标题
- 合并单元格
请尝试以下命令以查看布局检测结果。
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=tsr --output_dir=path_to_store_result
```
输入可以是图像或PDF的目录或者单个图像、PDF文件。您可以查看文件夹 `path_to_store_result` 其中包含图像和html页面这些页面展示了以下检测结果
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/cb24e81b-f2ba-49f3-ac09-883d75606f4c" width="1000"/>
</div>
<a name="3"></a>
## 3. 解析器
PDF、DOCX、EXCEL和PPT四种文档格式都有相应的解析器。最复杂的是PDF解析器因为PDF具有灵活性。PDF解析器的输出包括
- 在PDF中有自己位置的文本块页码和矩形位置
- 带有PDF裁剪图像的表格以及已经翻译成自然语言句子的内容。
- 图中带标题和文字的图。
### 简历
简历是一种非常复杂的文档。由各种格式的非结构化文本构成的简历可以被解析为包含近百个字段的结构化数据。我们还没有启用解析器,因为在解析过程之后才会启动处理方法。

View File

@@ -0,0 +1,2 @@
from beartype.claw import beartype_this_package
beartype_this_package()

View File

@@ -0,0 +1,24 @@
from .docx_parser import RAGDocxParser as DocxParser
from .excel_parser import RAGExcelParser as ExcelParser
from .html_parser import RAGHtmlParser as HtmlParser
from .json_parser import RAGJsonParser as JsonParser
from .markdown_parser import MarkdownElementExtractor
from .markdown_parser import RAGMarkdownParser as MarkdownParser
from .pdf_parser import PlainParser
from .pdf_parser import RAGPdfParser as PdfParser
from .ppt_parser import RAGPptParser as PptParser
from .txt_parser import RAGTxtParser as TxtParser
__all__ = [
"PdfParser",
"PlainParser",
"DocxParser",
"ExcelParser",
"PptParser",
"HtmlParser",
"JsonParser",
"MarkdownParser",
"TxtParser",
"MarkdownElementExtractor",
]

View File

@@ -0,0 +1,123 @@
from docx import Document
import re
import pandas as pd
from collections import Counter
from app.core.rag.nlp import rag_tokenizer
from io import BytesIO
class RAGDocxParser:
def __extract_table_content(self, tb):
df = []
for row in tb.rows:
df.append([c.text for c in row.cells])
return self.__compose_table_content(pd.DataFrame(df))
def __compose_table_content(self, df):
def blockType(b):
pattern = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年/-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "DT"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg")
]
for p, n in pattern:
if re.search(p, b):
return n
tks = [t for t in rag_tokenizer.tokenize(b).split() if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
if len(df) < 2:
return []
max_type = Counter([blockType(str(df.iloc[i, j])) for i in range(
1, len(df)) for j in range(len(df.iloc[i, :]))])
max_type = max(max_type.items(), key=lambda x: x[1])[0]
colnm = len(df.iloc[0, :])
hdrows = [0] # header is not necessarily appear in the first line
if max_type == "Nu":
for r in range(1, len(df)):
tys = Counter([blockType(str(df.iloc[r, j]))
for j in range(len(df.iloc[r, :]))])
tys = max(tys.items(), key=lambda x: x[1])[0]
if tys != max_type:
hdrows.append(r)
lines = []
for i in range(1, len(df)):
if i in hdrows:
continue
hr = [r - i for r in hdrows]
hr = [r for r in hr if r < 0]
t = len(hr) - 1
while t > 0:
if hr[t] - hr[t - 1] > 1:
hr = hr[t:]
break
t -= 1
headers = []
for j in range(len(df.iloc[i, :])):
t = []
for h in hr:
x = str(df.iloc[i + h, j]).strip()
if x in t:
continue
t.append(x)
t = ",".join(t)
if t:
t += ": "
headers.append(t)
cells = []
for j in range(len(df.iloc[i, :])):
if not str(df.iloc[i, j]):
continue
cells.append(headers[j] + str(df.iloc[i, j]))
lines.append(";".join(cells))
if colnm > 3:
return lines
return ["\n".join(lines)]
def __call__(self, fnm, from_page=0, to_page=100000000):
self.doc = Document(fnm) if isinstance(
fnm, str) else Document(BytesIO(fnm))
pn = 0 # parsed page
secs = [] # parsed contents
for p in self.doc.paragraphs:
if pn > to_page:
break
runs_within_single_paragraph = [] # save runs within the range of pages
for run in p.runs:
if pn > to_page:
break
if from_page <= pn < to_page and p.text.strip():
runs_within_single_paragraph.append(run.text) # append run.text first
# wrap page break checker into a static method
if 'lastRenderedPageBreak' in run._element.xml:
pn += 1
secs.append(("".join(runs_within_single_paragraph), p.style.name if hasattr(p.style, 'name') else '')) # then concat run.text as part of the paragraph
tbls = [self.__extract_table_content(tb) for tb in self.doc.tables]
return secs, tbls

View File

@@ -0,0 +1,210 @@
import logging
import re
import sys
from io import BytesIO
import pandas as pd
from openpyxl import Workbook, load_workbook
from app.core.rag.nlp import find_codec
# copied from `/openpyxl/cell/cell.py`
ILLEGAL_CHARACTERS_RE = re.compile(r"[\000-\010]|[\013-\014]|[\016-\037]")
class RAGExcelParser:
@staticmethod
def _load_excel_to_workbook(file_like_object):
if isinstance(file_like_object, bytes):
file_like_object = BytesIO(file_like_object)
# Read first 4 bytes to determine file type
file_like_object.seek(0)
file_head = file_like_object.read(4)
file_like_object.seek(0)
if not (file_head.startswith(b"PK\x03\x04") or file_head.startswith(b"\xd0\xcf\x11\xe0")):
logging.info("Not an Excel file, converting CSV to Excel Workbook")
try:
file_like_object.seek(0)
df = pd.read_csv(file_like_object)
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_csv:
raise Exception(f"Failed to parse CSV and convert to Excel Workbook: {e_csv}")
try:
return load_workbook(file_like_object, data_only=True)
except Exception as e:
logging.info(f"openpyxl load error: {e}, try pandas instead")
try:
file_like_object.seek(0)
try:
dfs = pd.read_excel(file_like_object, sheet_name=None)
return RAGExcelParser._dataframe_to_workbook(dfs)
except Exception as ex:
logging.info(f"pandas with default engine load error: {ex}, try calamine instead")
file_like_object.seek(0)
df = pd.read_excel(file_like_object, engine="calamine")
return RAGExcelParser._dataframe_to_workbook(df)
except Exception as e_pandas:
raise Exception(f"pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
@staticmethod
def _clean_dataframe(df: pd.DataFrame):
def clean_string(s):
if isinstance(s, str):
return ILLEGAL_CHARACTERS_RE.sub(" ", s)
return s
return df.apply(lambda col: col.map(clean_string))
@staticmethod
def _dataframe_to_workbook(df):
# if contains multiple sheets use _dataframes_to_workbook
if isinstance(df, dict) and len(df) > 1:
return RAGExcelParser._dataframes_to_workbook(df)
df = RAGExcelParser._clean_dataframe(df)
wb = Workbook()
ws = wb.active
ws.title = "Data"
for col_num, column_name in enumerate(df.columns, 1):
ws.cell(row=1, column=col_num, value=column_name)
for row_num, row in enumerate(df.values, 2):
for col_num, value in enumerate(row, 1):
ws.cell(row=row_num, column=col_num, value=value)
return wb
@staticmethod
def _dataframes_to_workbook(dfs: dict):
wb = Workbook()
default_sheet = wb.active
wb.remove(default_sheet)
for sheet_name, df in dfs.items():
df = RAGExcelParser._clean_dataframe(df)
ws = wb.create_sheet(title=sheet_name)
for col_num, column_name in enumerate(df.columns, 1):
ws.cell(row=1, column=col_num, value=column_name)
for row_num, row in enumerate(df.values, 2):
for col_num, value in enumerate(row, 1):
ws.cell(row=row_num, column=col_num, value=value)
return wb
def html(self, fnm, chunk_rows=256):
from html import escape
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
tb_chunks = []
def _fmt(v):
if v is None:
return ""
return str(v).strip()
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
rows = list(ws.rows)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
if not rows:
continue
tb_rows_0 = "<tr>"
for t in list(rows[0]):
tb_rows_0 += f"<th>{escape(_fmt(t.value))}</th>"
tb_rows_0 += "</tr>"
for chunk_i in range((len(rows) - 1) // chunk_rows + 1):
tb = ""
tb += f"<table><caption>{sheetname}</caption>"
tb += tb_rows_0
for r in list(rows[1 + chunk_i * chunk_rows : min(1 + (chunk_i + 1) * chunk_rows, len(rows))]):
tb += "<tr>"
for i, c in enumerate(r):
if c.value is None:
tb += "<td></td>"
else:
tb += f"<td>{escape(_fmt(c.value))}</td>"
tb += "</tr>"
tb += "</table>\n"
tb_chunks.append(tb)
return tb_chunks
def markdown(self, fnm):
import pandas as pd
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
try:
file_like_object.seek(0)
df = pd.read_excel(file_like_object)
except Exception as e:
logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
file_like_object.seek(0)
df = pd.read_csv(file_like_object)
df = df.replace(r"^\s*$", "", regex=True)
return df.to_markdown(index=False)
def __call__(self, fnm):
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
wb = RAGExcelParser._load_excel_to_workbook(file_like_object)
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]
try:
rows = list(ws.rows)
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
if not rows:
continue
ti = list(rows[0])
for r in list(rows[1:]):
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += ("" if t else "") + str(c.value)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
line += " ——" + sheetname
res.append(line)
return res
@staticmethod
def row_number(fnm, binary):
if fnm.split(".")[-1].lower().find("xls") >= 0:
wb = RAGExcelParser._load_excel_to_workbook(BytesIO(binary))
total = 0
for sheetname in wb.sheetnames:
try:
ws = wb[sheetname]
total += len(list(ws.rows))
except Exception as e:
logging.warning(f"Skip sheet '{sheetname}' due to rows access error: {e}")
continue
return total
if fnm.split(".")[-1].lower() in ["csv", "txt"]:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
return len(txt.split("\n"))
if __name__ == "__main__":
psr = RAGExcelParser()
psr(sys.argv[1])

View File

@@ -0,0 +1,118 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
from app.core.rag.common.constants import LLMType
from app.core.rag.common.connection_utils import timeout
from app.core.rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
return [
(
(figure_data[1], [figure_data[0]]),
[(0, 0, 0, 0, 0)],
)
for figure_data in figures_data_without_positions
if isinstance(figure_data[1], Image.Image)
]
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,vision_model=None,**kwargs):
if vision_model:
figures_data = vision_figure_parser_figure_data_wrapper(sections)
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tbls.extend(boosted_figures)
except Exception as e:
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
def vision_figure_parser_pdf_wrapper(tbls,callback=None,vision_model=None,**kwargs):
if vision_model:
def is_figure_item(item):
return (
isinstance(item[0][0], Image.Image) and
isinstance(item[0][1], list)
)
figures_data = [item for item in tbls if is_figure_item(item)]
try:
docx_vision_parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
boosted_figures = docx_vision_parser(callback=callback)
tbls = [item for item in tbls if not is_figure_item(item)]
tbls.extend(boosted_figures)
except Exception as e:
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
return tbls
shared_executor = ThreadPoolExecutor(max_workers=10)
class VisionFigureParser:
def __init__(self, vision_model, figures_data, *args, **kwargs):
self.vision_model = vision_model
self._extract_figures_info(figures_data)
assert len(self.figures) == len(self.descriptions)
assert not self.positions or (len(self.figures) == len(self.positions))
def _extract_figures_info(self, figures_data):
self.figures = []
self.descriptions = []
self.positions = []
for item in figures_data:
# position
if len(item) == 2 and isinstance(item[0], tuple) and len(item[0]) == 2 and isinstance(item[1], list) and isinstance(item[1][0], tuple) and len(item[1][0]) == 5:
img_desc = item[0]
assert len(img_desc) == 2 and isinstance(img_desc[0], Image.Image) and isinstance(img_desc[1], list), "Should be (figure, [description])"
self.figures.append(img_desc[0])
self.descriptions.append(img_desc[1])
self.positions.append(item[1])
else:
assert len(item) == 2 and isinstance(item[0], Image.Image) and isinstance(item[1], list), f"Unexpected form of figure data: get {len(item)=}, {item=}"
self.figures.append(item[0])
self.descriptions.append(item[1])
def _assemble(self):
self.assembled = []
self.has_positions = len(self.positions) != 0
for i in range(len(self.figures)):
figure = self.figures[i]
desc = self.descriptions[i]
pos = self.positions[i] if self.has_positions else None
figure_desc = (figure, desc)
if pos is not None:
self.assembled.append((figure_desc, pos))
else:
self.assembled.append((figure_desc,))
return self.assembled
def __call__(self, **kwargs):
callback = kwargs.get("callback", lambda prog, msg: None)
@timeout(30, 3)
def process(figure_idx, figure_binary):
description_text = picture_vision_llm_chunk(
binary=figure_binary,
vision_model=self.vision_model,
prompt=vision_llm_figure_describe_prompt(),
callback=callback,
)
return figure_idx, description_text
futures = []
for idx, img_binary in enumerate(self.figures or []):
futures.append(shared_executor.submit(process, idx, img_binary))
for future in as_completed(futures):
figure_num, txt = future.result()
if txt:
self.descriptions[figure_num] = txt + "\n".join(self.descriptions[figure_num])
self._assemble()
return self.assembled

View File

@@ -0,0 +1,197 @@
from app.core.rag.nlp import find_codec, rag_tokenizer
import uuid
import chardet
from bs4 import BeautifulSoup, NavigableString, Tag, Comment
import html
def get_encoding(file):
with open(file,'rb') as f:
tmp = chardet.detect(f.read())
return tmp['encoding']
BLOCK_TAGS = [
"h1", "h2", "h3", "h4", "h5", "h6",
"p", "div", "article", "section", "aside",
"ul", "ol", "li",
"table", "pre", "code", "blockquote",
"figure", "figcaption"
]
TITLE_TAGS = {"h1": "#", "h2": "##", "h3": "###", "h4": "#####", "h5": "#####", "h6": "######"}
class RAGHtmlParser:
def __call__(self, fnm, binary=None, chunk_token_num=512):
if binary:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
else:
with open(fnm, "r",encoding=get_encoding(fnm)) as f:
txt = f.read()
return self.parser_txt(txt, chunk_token_num)
@classmethod
def parser_txt(cls, txt, chunk_token_num):
if not isinstance(txt, str):
raise TypeError("txt type should be string!")
temp_sections = []
soup = BeautifulSoup(txt, "html5lib")
# delete <style> tag
for style_tag in soup.find_all(["style", "script"]):
style_tag.decompose()
# delete <script> tag in <div>
for div_tag in soup.find_all("div"):
for script_tag in div_tag.find_all("script"):
script_tag.decompose()
# delete inline style
for tag in soup.find_all(True):
if 'style' in tag.attrs:
del tag.attrs['style']
# delete HTML comment
for comment in soup.find_all(string=lambda text: isinstance(text, Comment)):
comment.extract()
cls.read_text_recursively(soup.body, temp_sections, chunk_token_num=chunk_token_num)
block_txt_list, table_list = cls.merge_block_text(temp_sections)
sections = cls.chunk_block(block_txt_list, chunk_token_num=chunk_token_num)
for table in table_list:
sections.append(table.get("content", ""))
return sections
@classmethod
def split_table(cls, html_table, chunk_token_num=512):
soup = BeautifulSoup(html_table, "html.parser")
rows = soup.find_all("tr")
tables = []
current_table = []
current_count = 0
table_str_list = []
for row in rows:
tks_str = rag_tokenizer.tokenize(str(row))
token_count = len(tks_str.split(" ")) if tks_str else 0
if current_count + token_count > chunk_token_num:
tables.append(current_table)
current_table = []
current_count = 0
current_table.append(row)
current_count += token_count
if current_table:
tables.append(current_table)
for table_rows in tables:
new_table = soup.new_tag("table")
for row in table_rows:
new_table.append(row)
table_str_list.append(str(new_table))
return table_str_list
@classmethod
def read_text_recursively(cls, element, parser_result, chunk_token_num=512, parent_name=None, block_id=None):
if isinstance(element, NavigableString):
content = element.strip()
def is_valid_html(content):
try:
soup = BeautifulSoup(content, "html.parser")
return bool(soup.find())
except Exception:
return False
return_info = []
if content:
if is_valid_html(content):
soup = BeautifulSoup(content, "html.parser")
child_info = cls.read_text_recursively(soup, parser_result, chunk_token_num, element.name, block_id)
parser_result.extend(child_info)
else:
info = {"content": element.strip(), "tag_name": "inner_text", "metadata": {"block_id": block_id}}
if parent_name:
info["tag_name"] = parent_name
return_info.append(info)
return return_info
elif isinstance(element, Tag):
if str.lower(element.name) == "table":
table_info_list = []
table_id = str(uuid.uuid1())
table_list = [html.unescape(str(element))]
for t in table_list:
table_info_list.append({"content": t, "tag_name": "table",
"metadata": {"table_id": table_id, "index": table_list.index(t)}})
return table_info_list
else:
block_id = None
if str.lower(element.name) in BLOCK_TAGS:
block_id = str(uuid.uuid1())
for child in element.children:
child_info = cls.read_text_recursively(child, parser_result, chunk_token_num, element.name,
block_id)
parser_result.extend(child_info)
return []
@classmethod
def merge_block_text(cls, parser_result):
block_content = []
current_content = ""
table_info_list = []
lask_block_id = None
for item in parser_result:
content = item.get("content")
tag_name = item.get("tag_name")
title_flag = tag_name in TITLE_TAGS
block_id = item.get("metadata", {}).get("block_id")
if block_id:
if title_flag:
content = f"{TITLE_TAGS[tag_name]} {content}"
if lask_block_id != block_id:
if lask_block_id is not None:
block_content.append(current_content)
current_content = content
lask_block_id = block_id
else:
current_content += (" " if current_content else "") + content
else:
if tag_name == "table":
table_info_list.append(item)
else:
current_content += (" " if current_content else "" + content)
if current_content:
block_content.append(current_content)
return block_content, table_info_list
@classmethod
def chunk_block(cls, block_txt_list, chunk_token_num=512):
chunks = []
current_block = ""
current_token_count = 0
for block in block_txt_list:
tks_str = rag_tokenizer.tokenize(block)
block_token_count = len(tks_str.split(" ")) if tks_str else 0
if block_token_count > chunk_token_num:
if current_block:
chunks.append(current_block)
start = 0
tokens = tks_str.split(" ")
while start < len(tokens):
end = start + chunk_token_num
split_tokens = tokens[start:end]
chunks.append(" ".join(split_tokens))
start = end
current_block = ""
current_token_count = 0
else:
if current_token_count + block_token_count <= chunk_token_num:
current_block += ("\n" if current_block else "") + block
current_token_count += block_token_count
else:
chunks.append(current_block)
current_block = block
current_token_count = block_token_count
if current_block:
chunks.append(current_block)
return chunks

View File

@@ -0,0 +1,159 @@
import json
from typing import Any
from app.core.rag.nlp import find_codec
class RAGJsonParser:
def __init__(self, max_chunk_size: int = 2000, min_chunk_size: int | None = None):
super().__init__()
self.max_chunk_size = max_chunk_size * 2
self.min_chunk_size = min_chunk_size if min_chunk_size is not None else max(max_chunk_size - 200, 50)
def __call__(self, filename):
with open(filename, "r") as f:
txt = f.read()
if self.is_jsonl_format(txt):
sections = self._parse_jsonl(txt)
else:
sections = self._parse_json(txt)
return sections
@staticmethod
def _json_size(data: dict) -> int:
"""Calculate the size of the serialized JSON object."""
return len(json.dumps(data, ensure_ascii=False))
@staticmethod
def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
"""Set a value in a nested dictionary based on the given path."""
for key in path[:-1]:
d = d.setdefault(key, {})
d[path[-1]] = value
def _list_to_dict_preprocessing(self, data: Any) -> Any:
if isinstance(data, dict):
# Process each key-value pair in the dictionary
return {k: self._list_to_dict_preprocessing(v) for k, v in data.items()}
elif isinstance(data, list):
# Convert the list to a dictionary with index-based keys
return {str(i): self._list_to_dict_preprocessing(item) for i, item in enumerate(data)}
else:
# Base case: the item is neither a dict nor a list, so return it unchanged
return data
def _json_split(
self,
data,
current_path: list[str] | None,
chunks: list[dict] | None,
) -> list[dict]:
"""
Split json into maximum size dictionaries while preserving structure.
"""
current_path = current_path or []
chunks = chunks or [{}]
if isinstance(data, dict):
for key, value in data.items():
new_path = current_path + [key]
chunk_size = self._json_size(chunks[-1])
size = self._json_size({key: value})
remaining = self.max_chunk_size - chunk_size
if size < remaining:
# Add item to current chunk
self._set_nested_dict(chunks[-1], new_path, value)
else:
if chunk_size >= self.min_chunk_size:
# Chunk is big enough, start a new chunk
chunks.append({})
# Iterate
self._json_split(value, new_path, chunks)
else:
# handle single item
self._set_nested_dict(chunks[-1], current_path, data)
return chunks
def split_json(
self,
json_data,
convert_lists: bool = False,
) -> list[dict]:
"""Splits JSON into a list of JSON chunks"""
if convert_lists:
preprocessed_data = self._list_to_dict_preprocessing(json_data)
chunks = self._json_split(preprocessed_data, None, None)
else:
chunks = self._json_split(json_data, None, None)
# Remove the last chunk if it's empty
if not chunks[-1]:
chunks.pop()
return chunks
def split_text(
self,
json_data: dict[str, Any],
convert_lists: bool = False,
ensure_ascii: bool = True,
) -> list[str]:
"""Splits JSON into a list of JSON formatted strings"""
chunks = self.split_json(json_data=json_data, convert_lists=convert_lists)
# Convert to string
return [json.dumps(chunk, ensure_ascii=ensure_ascii) for chunk in chunks]
def _parse_json(self, content: str) -> list[str]:
sections = []
try:
json_data = json.loads(content)
chunks = self.split_json(json_data, True)
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
except json.JSONDecodeError:
pass
return sections
def _parse_jsonl(self, content: str) -> list[str]:
lines = content.strip().splitlines()
all_chunks = []
for line in lines:
if not line.strip():
continue
try:
data = json.loads(line)
chunks = self.split_json(data, convert_lists=True)
all_chunks.extend(json.dumps(chunk, ensure_ascii=False) for chunk in chunks if chunk)
except json.JSONDecodeError:
continue
return all_chunks
def is_jsonl_format(self, txt: str, sample_limit: int = 10, threshold: float = 0.8) -> bool:
lines = [line.strip() for line in txt.strip().splitlines() if line.strip()]
if not lines:
return False
try:
json.loads(txt)
return False
except json.JSONDecodeError:
pass
sample_limit = min(len(lines), sample_limit)
sample_lines = lines[:sample_limit]
valid_lines = sum(1 for line in sample_lines if self._is_valid_json(line))
if not valid_lines:
return False
return (valid_lines / len(sample_lines)) >= threshold
def _is_valid_json(self, line: str) -> bool:
try:
json.loads(line)
return True
except json.JSONDecodeError:
return False

View File

@@ -0,0 +1,277 @@
import re
from markdown import markdown
class RAGMarkdownParser:
def __init__(self, chunk_token_num=128):
self.chunk_token_num = int(chunk_token_num)
def extract_tables_and_remainder(self, markdown_text, separate_tables=True):
tables = []
working_text = markdown_text
def replace_tables_with_rendered_html(pattern, table_list, render=True):
new_text = ""
last_end = 0
for match in pattern.finditer(working_text):
raw_table = match.group()
table_list.append(raw_table)
if separate_tables:
# Skip this match (i.e., remove it)
new_text += working_text[last_end : match.start()] + "\n\n"
else:
# Replace with rendered HTML
html_table = markdown(raw_table, extensions=["markdown.extensions.tables"]) if render else raw_table
new_text += working_text[last_end : match.start()] + html_table + "\n\n"
last_end = match.end()
new_text += working_text[last_end:]
return new_text
if "|" in markdown_text: # for optimize performance
# Standard Markdown table
border_table_pattern = re.compile(
r"""
(?:\n|^)
(?:\|.*?\|.*?\|.*?\n)
(?:\|(?:\s*[:-]+[-| :]*\s*)\|.*?\n)
(?:\|.*?\|.*?\|.*?\n)+
""",
re.VERBOSE,
)
working_text = replace_tables_with_rendered_html(border_table_pattern, tables)
# Borderless Markdown table
no_border_table_pattern = re.compile(
r"""
(?:\n|^)
(?:\S.*?\|.*?\n)
(?:(?:\s*[:-]+[-| :]*\s*).*?\n)
(?:\S.*?\|.*?\n)+
""",
re.VERBOSE,
)
working_text = replace_tables_with_rendered_html(no_border_table_pattern, tables)
# Replace any TAGS e.g. <table ...> to <table>
TAGS = ["table", "td", "tr", "th", "tbody", "thead", "div"]
table_with_attributes_pattern = re.compile(
rf"<(?:{'|'.join(TAGS)})[^>]*>", re.IGNORECASE
)
def replace_tag(m):
tag_name = re.match(r"<(\w+)", m.group()).group(1)
return "<{}>".format(tag_name)
working_text = re.sub(table_with_attributes_pattern, replace_tag, working_text)
if "<table>" in working_text.lower(): # for optimize performance
# HTML table extraction - handle possible html/body wrapper tags
html_table_pattern = re.compile(
r"""
(?:\n|^)
\s*
(?:
# case1: <html><body><table>...</table></body></html>
(?:<html[^>]*>\s*<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>\s*</html>)
|
# case2: <body><table>...</table></body>
(?:<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>)
|
# case3: only<table>...</table>
(?:<table[^>]*>.*?</table>)
)
\s*
(?=\n|$)
""",
re.VERBOSE | re.DOTALL | re.IGNORECASE,
)
def replace_html_tables():
nonlocal working_text
new_text = ""
last_end = 0
for match in html_table_pattern.finditer(working_text):
raw_table = match.group()
tables.append(raw_table)
if separate_tables:
new_text += working_text[last_end : match.start()] + "\n\n"
else:
new_text += working_text[last_end : match.start()] + raw_table + "\n\n"
last_end = match.end()
new_text += working_text[last_end:]
working_text = new_text
replace_html_tables()
return working_text, tables
class MarkdownElementExtractor:
def __init__(self, markdown_content):
self.markdown_content = markdown_content
self.lines = markdown_content.split("\n")
def get_delimiters(self,delimiters):
toks = re.findall(r"`([^`]+)`", delimiters)
toks = sorted(set(toks), key=lambda x: -len(x))
return "|".join(re.escape(t) for t in toks if t)
def extract_elements(self,delimiter=None):
"""Extract individual elements (headers, code blocks, lists, etc.)"""
sections = []
i = 0
dels=""
if delimiter:
dels = self.get_delimiters(delimiter)
if len(dels) > 0:
text = "\n".join(self.lines)
parts = re.split(dels, text)
sections = [p.strip() for p in parts if p and p.strip()]
return sections
while i < len(self.lines):
line = self.lines[i]
if re.match(r"^#{1,6}\s+.*$", line):
# header
element = self._extract_header(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip().startswith("```"):
# code block
element = self._extract_code_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line):
# list block
element = self._extract_list_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip().startswith(">"):
# blockquote
element = self._extract_blockquote(i)
sections.append(element["content"])
i = element["end_line"] + 1
elif line.strip():
# text block (paragraphs and inline elements until next block element)
element = self._extract_text_block(i)
sections.append(element["content"])
i = element["end_line"] + 1
else:
i += 1
sections = [section for section in sections if section.strip()]
return sections
def _extract_header(self, start_pos):
return {
"type": "header",
"content": self.lines[start_pos],
"start_line": start_pos,
"end_line": start_pos,
}
def _extract_code_block(self, start_pos):
end_pos = start_pos
content_lines = [self.lines[start_pos]]
# Find the end of the code block
for i in range(start_pos + 1, len(self.lines)):
content_lines.append(self.lines[i])
end_pos = i
if self.lines[i].strip().startswith("```"):
break
return {
"type": "code_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_list_block(self, start_pos):
end_pos = start_pos
content_lines = []
i = start_pos
while i < len(self.lines):
line = self.lines[i]
# check if this line is a list item or continuation of a list
if (
re.match(r"^\s*[-*+]\s+.*$", line)
or re.match(r"^\s*\d+\.\s+.*$", line)
or (i > start_pos and not line.strip())
or (i > start_pos and re.match(r"^\s{2,}[-*+]\s+.*$", line))
or (i > start_pos and re.match(r"^\s{2,}\d+\.\s+.*$", line))
or (i > start_pos and re.match(r"^\s+\w+.*$", line))
):
content_lines.append(line)
end_pos = i
i += 1
else:
break
return {
"type": "list_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_blockquote(self, start_pos):
end_pos = start_pos
content_lines = []
i = start_pos
while i < len(self.lines):
line = self.lines[i]
if line.strip().startswith(">") or (i > start_pos and not line.strip()):
content_lines.append(line)
end_pos = i
i += 1
else:
break
return {
"type": "blockquote",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}
def _extract_text_block(self, start_pos):
"""Extract a text block (paragraphs, inline elements) until next block element"""
end_pos = start_pos
content_lines = [self.lines[start_pos]]
i = start_pos + 1
while i < len(self.lines):
line = self.lines[i]
# stop if we encounter a block element
if re.match(r"^#{1,6}\s+.*$", line) or line.strip().startswith("```") or re.match(r"^\s*[-*+]\s+.*$", line) or re.match(r"^\s*\d+\.\s+.*$", line) or line.strip().startswith(">"):
break
elif not line.strip():
# check if the next line is a block element
if i + 1 < len(self.lines) and (
re.match(r"^#{1,6}\s+.*$", self.lines[i + 1])
or self.lines[i + 1].strip().startswith("```")
or re.match(r"^\s*[-*+]\s+.*$", self.lines[i + 1])
or re.match(r"^\s*\d+\.\s+.*$", self.lines[i + 1])
or self.lines[i + 1].strip().startswith(">")
):
break
else:
content_lines.append(line)
end_pos = i
i += 1
else:
content_lines.append(line)
end_pos = i
i += 1
return {
"type": "text_block",
"content": "\n".join(content_lines),
"start_line": start_pos,
"end_line": end_pos,
}

View File

@@ -0,0 +1,524 @@
import json
import logging
import os
import platform
import re
import subprocess
import sys
import tempfile
import threading
import time
import zipfile
from io import BytesIO
from os import PathLike
from pathlib import Path
from queue import Empty, Queue
from typing import Any, Callable, Optional
import numpy as np
import pdfplumber
import requests
from PIL import Image
from strenum import StrEnum
from .pdf_parser import RAGPdfParser
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
class MinerUContentType(StrEnum):
IMAGE = "image"
TABLE = "table"
TEXT = "text"
EQUATION = "equation"
CODE = "code"
LIST = "list"
DISCARDED = "discarded"
class MinerUParser(RAGPdfParser):
def __init__(self, mineru_path: str = "mineru", mineru_api: str = "http://host.docker.internal:9987", mineru_server_url: str = ""):
self.mineru_path = Path(mineru_path)
self.mineru_api = mineru_api.rstrip("/")
self.mineru_server_url = mineru_server_url.rstrip("/")
self.using_api = False
self.logger = logging.getLogger(self.__class__.__name__)
def _extract_zip_no_root(self, zip_path, extract_to, root_dir):
with zipfile.ZipFile(zip_path, "r") as zip_ref:
if not root_dir:
files = zip_ref.namelist()
if files and files[0].endswith("/"):
root_dir = files[0]
else:
root_dir = None
if not root_dir or not root_dir.endswith("/"):
self.logger.info(f"[MinerU] No root directory found, extracting all...fff{root_dir}")
zip_ref.extractall(extract_to)
return
root_len = len(root_dir)
for member in zip_ref.infolist():
filename = member.filename
if filename == root_dir:
self.logger.info("[MinerU] Ignore root folder...")
continue
path = filename
if path.startswith(root_dir):
path = path[root_len:]
full_path = os.path.join(extract_to, path)
if member.is_dir():
os.makedirs(full_path, exist_ok=True)
else:
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, "wb") as f:
f.write(zip_ref.read(filename))
def _is_http_endpoint_valid(self, url, timeout=5):
try:
response = requests.head(url, timeout=timeout, allow_redirects=True)
return response.status_code in [200, 301, 302, 307, 308]
except Exception:
return False
def check_installation(self, backend: str = "pipeline", server_url: Optional[str] = None) -> tuple[bool, str]:
reason = ""
valid_backends = ["pipeline", "vlm-http-client", "vlm-transformers", "vlm-vllm-engine"]
if backend not in valid_backends:
reason = "[MinerU] Invalid backend '{backend}'. Valid backends are: {valid_backends}"
logging.warning(reason)
return False, reason
subprocess_kwargs = {
"capture_output": True,
"text": True,
"check": True,
"encoding": "utf-8",
"errors": "ignore",
}
if platform.system() == "Windows":
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
if server_url is None:
server_url = self.mineru_server_url
if backend == "vlm-http-client" and server_url:
try:
server_accessible = self._is_http_endpoint_valid(server_url + "/openapi.json")
logging.info(f"[MinerU] vlm-http-client server check: {server_accessible}")
if server_accessible:
self.using_api = False # We are using http client, not API
return True, reason
else:
reason = f"[MinerU] vlm-http-client server not accessible: {server_url}"
logging.warning(f"[MinerU] vlm-http-client server not accessible: {server_url}")
return False, reason
except Exception as e:
logging.warning(f"[MinerU] vlm-http-client server check failed: {e}")
try:
response = requests.get(server_url, timeout=5)
logging.info(f"[MinerU] vlm-http-client server connection check: success with status {response.status_code}")
self.using_api = False
return True, reason
except Exception as e:
reason = f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}"
logging.warning(f"[MinerU] vlm-http-client server connection check failed: {server_url}: {e}")
return False, reason
try:
result = subprocess.run([str(self.mineru_path), "--version"], **subprocess_kwargs)
version_info = result.stdout.strip()
if version_info:
logging.info(f"[MinerU] Detected version: {version_info}")
else:
logging.info("[MinerU] Detected MinerU, but version info is empty.")
return True, reason
except subprocess.CalledProcessError as e:
logging.warning(f"[MinerU] Execution failed (exit code {e.returncode}).")
except FileNotFoundError:
logging.warning("[MinerU] MinerU not found. Please install it via: pip install -U 'mineru[core]'")
except Exception as e:
logging.error(f"[MinerU] Unexpected error during installation check: {e}")
# If executable check fails, try API check
try:
if self.mineru_api:
# check openapi.json
openapi_exists = self._is_http_endpoint_valid(self.mineru_api + "/openapi.json")
if not openapi_exists:
reason = "[MinerU] Failed to detect vaild MinerU API server"
return openapi_exists, reason
logging.info(f"[MinerU] Detected {self.mineru_api}/openapi.json: {openapi_exists}")
self.using_api = openapi_exists
return openapi_exists, reason
else:
logging.info("[MinerU] api not exists.")
except Exception as e:
reason = f"[MinerU] Unexpected error during api check: {e}"
logging.error(f"[MinerU] Unexpected error during api check: {e}")
return False, reason
def _run_mineru(
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
):
if self.using_api:
self._run_mineru_api(input_path, output_dir, method, backend, lang, callback)
else:
self._run_mineru_executable(input_path, output_dir, method, backend, lang, server_url, callback)
def _run_mineru_api(self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, callback: Optional[Callable] = None):
OUTPUT_ZIP_PATH = os.path.join(str(output_dir), "output.zip")
pdf_file_path = str(input_path)
if not os.path.exists(pdf_file_path):
raise RuntimeError(f"[MinerU] PDF file not exists: {pdf_file_path}")
pdf_file_name = Path(pdf_file_path).stem.strip()
output_path = os.path.join(str(output_dir), pdf_file_name, method)
os.makedirs(output_path, exist_ok=True)
files = {"files": (pdf_file_name + ".pdf", open(pdf_file_path, "rb"), "application/pdf")}
data = {
"output_dir": "./output",
"lang_list": lang,
"backend": backend,
"parse_method": method,
"formula_enable": True,
"table_enable": True,
"server_url": None,
"return_md": True,
"return_middle_json": True,
"return_model_output": True,
"return_content_list": True,
"return_images": True,
"response_format_zip": True,
"start_page_id": 0,
"end_page_id": 99999,
}
headers = {"Accept": "application/json"}
try:
self.logger.info(f"[MinerU] invoke api: {self.mineru_api}/file_parse")
if callback:
callback(0.20, f"[MinerU] invoke api: {self.mineru_api}/file_parse")
response = requests.post(url=f"{self.mineru_api}/file_parse", files=files, data=data, headers=headers, timeout=1800)
response.raise_for_status()
if response.headers.get("Content-Type") == "application/zip":
self.logger.info(f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
if callback:
callback(0.30, f"[MinerU] zip file returned, saving to {OUTPUT_ZIP_PATH}...")
with open(OUTPUT_ZIP_PATH, "wb") as f:
f.write(response.content)
self.logger.info(f"[MinerU] Unzip to {output_path}...")
self._extract_zip_no_root(OUTPUT_ZIP_PATH, output_path, pdf_file_name + "/")
if callback:
callback(0.40, f"[MinerU] Unzip to {output_path}...")
else:
self.logger.warning("[MinerU] not zip returned from api%s " % response.headers.get("Content-Type"))
except Exception as e:
raise RuntimeError(f"[MinerU] api failed with exception {e}")
self.logger.info("[MinerU] Api completed successfully.")
def _run_mineru_executable(
self, input_path: Path, output_dir: Path, method: str = "auto", backend: str = "pipeline", lang: Optional[str] = None, server_url: Optional[str] = None, callback: Optional[Callable] = None
):
cmd = [str(self.mineru_path), "-p", str(input_path), "-o", str(output_dir), "-m", method]
if backend:
cmd.extend(["-b", backend])
if lang:
cmd.extend(["-l", lang])
if server_url and backend == "vlm-http-client":
cmd.extend(["-u", server_url])
self.logger.info(f"[MinerU] Running command: {' '.join(cmd)}")
subprocess_kwargs = {
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
"text": True,
"encoding": "utf-8",
"errors": "ignore",
"bufsize": 1,
}
if platform.system() == "Windows":
subprocess_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0)
process = subprocess.Popen(cmd, **subprocess_kwargs)
stdout_queue, stderr_queue = Queue(), Queue()
def enqueue_output(pipe, queue, prefix):
for line in iter(pipe.readline, ""):
if line.strip():
queue.put((prefix, line.strip()))
pipe.close()
threading.Thread(target=enqueue_output, args=(process.stdout, stdout_queue, "STDOUT"), daemon=True).start()
threading.Thread(target=enqueue_output, args=(process.stderr, stderr_queue, "STDERR"), daemon=True).start()
while process.poll() is None:
for q in (stdout_queue, stderr_queue):
try:
while True:
prefix, line = q.get_nowait()
if prefix == "STDOUT":
self.logger.info(f"[MinerU] {line}")
else:
self.logger.warning(f"[MinerU] {line}")
except Empty:
pass
time.sleep(0.1)
return_code = process.wait()
if return_code != 0:
raise RuntimeError(f"[MinerU] Process failed with exit code {return_code}")
self.logger.info("[MinerU] Command completed successfully.")
def __images__(self, fnm, zoomin: int = 1, page_from=0, page_to=600, callback=None):
self.page_from = page_from
self.page_to = page_to
try:
with pdfplumber.open(fnm) if isinstance(fnm, (str, PathLike)) else pdfplumber.open(BytesIO(fnm)) as pdf:
self.pdf = pdf
self.page_images = [p.to_image(resolution=72 * zoomin, antialias=True).original for _, p in enumerate(self.pdf.pages[page_from:page_to])]
except Exception as e:
self.page_images = None
self.total_page = 0
logging.exception(e)
def _line_tag(self, bx):
pn = [bx["page_idx"] + 1]
positions = bx["bbox"]
x0, top, x1, bott = positions
if hasattr(self, "page_images") and self.page_images and len(self.page_images) > bx["page_idx"]:
page_width, page_height = self.page_images[bx["page_idx"]].size
x0 = (x0 / 1000.0) * page_width
x1 = (x1 / 1000.0) * page_width
top = (top / 1000.0) * page_height
bott = (bott / 1000.0) * page_height
return "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##".format("-".join([str(p) for p in pn]), x0, x1, top, bott)
def crop(self, text, ZM=1, need_position=False):
imgs = []
poss = self.extract_positions(text)
if not poss:
if need_position:
return None, None
return
max_width = max(np.max([right - left for (_, left, right, _, _) in poss]), 6)
GAP = 6
pos = poss[0]
poss.insert(0, ([pos[0][0]], pos[1], pos[2], max(0, pos[3] - 120), max(pos[3] - GAP, 0)))
pos = poss[-1]
poss.append(([pos[0][-1]], pos[1], pos[2], min(self.page_images[pos[0][-1]].size[1], pos[4] + GAP), min(self.page_images[pos[0][-1]].size[1], pos[4] + 120)))
positions = []
for ii, (pns, left, right, top, bottom) in enumerate(poss):
right = left + max_width
if bottom <= top:
bottom = top + 2
for pn in pns[1:]:
bottom += self.page_images[pn - 1].size[1]
img0 = self.page_images[pns[0]]
x0, y0, x1, y1 = int(left), int(top), int(right), int(min(bottom, img0.size[1]))
crop0 = img0.crop((x0, y0, x1, y1))
imgs.append(crop0)
if 0 < ii < len(poss) - 1:
positions.append((pns[0] + self.page_from, x0, x1, y0, y1))
bottom -= img0.size[1]
for pn in pns[1:]:
page = self.page_images[pn]
x0, y0, x1, y1 = int(left), 0, int(right), int(min(bottom, page.size[1]))
cimgp = page.crop((x0, y0, x1, y1))
imgs.append(cimgp)
if 0 < ii < len(poss) - 1:
positions.append((pn + self.page_from, x0, x1, y0, y1))
bottom -= page.size[1]
if not imgs:
if need_position:
return None, None
return
height = 0
for img in imgs:
height += img.size[1] + GAP
height = int(height)
width = int(np.max([i.size[0] for i in imgs]))
pic = Image.new("RGB", (width, height), (245, 245, 245))
height = 0
for ii, img in enumerate(imgs):
if ii == 0 or ii + 1 == len(imgs):
img = img.convert("RGBA")
overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
overlay.putalpha(128)
img = Image.alpha_composite(img, overlay).convert("RGB")
pic.paste(img, (0, int(height)))
height += img.size[1] + GAP
if need_position:
return pic, positions
return pic
@staticmethod
def extract_positions(txt: str):
poss = []
for tag in re.findall(r"@@[0-9-]+\t[0-9.\t]+##", txt):
pn, left, right, top, bottom = tag.strip("#").strip("@").split("\t")
left, right, top, bottom = float(left), float(right), float(top), float(bottom)
poss.append(([int(p) - 1 for p in pn.split("-")], left, right, top, bottom))
return poss
def _read_output(self, output_dir: Path, file_stem: str, method: str = "auto", backend: str = "pipeline") -> list[dict[str, Any]]:
subdir = output_dir / file_stem / method
if backend.startswith("vlm-"):
subdir = output_dir / file_stem / "vlm"
json_file = subdir / f"{file_stem}_content_list.json"
if not json_file.exists():
raise FileNotFoundError(f"[MinerU] Missing output file: {json_file}")
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
for item in data:
for key in ("img_path", "table_img_path", "equation_img_path"):
if key in item and item[key]:
item[key] = str((subdir / item[key]).resolve())
return data
def _transfer_to_sections(self, outputs: list[dict[str, Any]]):
sections = []
for output in outputs:
match output["type"]:
case MinerUContentType.TEXT:
section = output["text"]
case MinerUContentType.TABLE:
section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(output.get("table_footnote", []))
if not section.strip():
section = "FAILED TO PARSE TABLE"
case MinerUContentType.IMAGE:
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
case MinerUContentType.EQUATION:
section = output["text"]
case MinerUContentType.CODE:
section = output["code_body"] + "\n".join(output.get("code_caption", []))
case MinerUContentType.LIST:
section = "\n".join(output.get("list_items", []))
case MinerUContentType.DISCARDED:
pass
if section:
sections.append((section, self._line_tag(output)))
return sections
def _transfer_to_tables(self, outputs: list[dict[str, Any]]):
return []
def parse_pdf(
self,
filepath: str | PathLike[str],
binary: BytesIO | bytes,
callback: Optional[Callable] = None,
*,
output_dir: Optional[str] = None,
backend: str = "pipeline",
lang: Optional[str] = None,
method: str = "auto",
server_url: Optional[str] = None,
delete_output: bool = True,
) -> tuple:
import shutil
temp_pdf = None
created_tmp_dir = False
# remove spaces, or mineru crash, and _read_output fail too
file_path = Path(filepath)
pdf_file_name = file_path.stem.replace(" ", "") + ".pdf"
pdf_file_path_valid = os.path.join(file_path.parent, pdf_file_name)
if binary:
temp_dir = Path(tempfile.mkdtemp(prefix="mineru_bin_pdf_"))
temp_pdf = temp_dir / pdf_file_name
with open(temp_pdf, "wb") as f:
f.write(binary)
pdf = temp_pdf
self.logger.info(f"[MinerU] Received binary PDF -> {temp_pdf}")
if callback:
callback(0.15, f"[MinerU] Received binary PDF -> {temp_pdf}")
else:
if pdf_file_path_valid != filepath:
self.logger.info(f"[MinerU] Remove all space in file name: {pdf_file_path_valid}")
shutil.move(filepath, pdf_file_path_valid)
pdf = Path(pdf_file_path_valid)
if not pdf.exists():
if callback:
callback(-1, f"[MinerU] PDF not found: {pdf}")
raise FileNotFoundError(f"[MinerU] PDF not found: {pdf}")
if output_dir:
out_dir = Path(output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
else:
out_dir = Path(tempfile.mkdtemp(prefix="mineru_pdf_"))
created_tmp_dir = True
self.logger.info(f"[MinerU] Output directory: {out_dir}")
if callback:
callback(0.15, f"[MinerU] Output directory: {out_dir}")
self.__images__(pdf, zoomin=1)
try:
self._run_mineru(pdf, out_dir, method=method, backend=backend, lang=lang, server_url=server_url, callback=callback)
outputs = self._read_output(out_dir, pdf.stem, method=method, backend=backend)
self.logger.info(f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
if callback:
callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.")
return self._transfer_to_sections(outputs), self._transfer_to_tables(outputs)
finally:
if temp_pdf and temp_pdf.exists():
try:
temp_pdf.unlink()
temp_pdf.parent.rmdir()
except Exception:
pass
if delete_output and created_tmp_dir and out_dir.exists():
try:
shutil.rmtree(out_dir)
except Exception:
pass
if __name__ == "__main__":
parser = MinerUParser("mineru")
ok, reason = parser.check_installation()
print("MinerU available:", ok)
filepath = ""
with open(filepath, "rb") as file:
outputs = parser.parse_pdf(filepath=filepath, binary=file.read())
for output in outputs:
print(output)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,83 @@
import logging
from io import BytesIO
from pptx import Presentation
class RAGPptParser:
def __init__(self):
super().__init__()
def __get_bulleted_text(self, paragraph):
is_bulleted = bool(paragraph._p.xpath("./a:pPr/a:buChar")) or bool(paragraph._p.xpath("./a:pPr/a:buAutoNum")) or bool(paragraph._p.xpath("./a:pPr/a:buBlip"))
if is_bulleted:
return f"{' '* paragraph.level}.{paragraph.text}"
else:
return paragraph.text
def __extract(self, shape):
try:
# First try to get text content
if hasattr(shape, 'has_text_frame') and shape.has_text_frame:
text_frame = shape.text_frame
texts = []
for paragraph in text_frame.paragraphs:
if paragraph.text.strip():
texts.append(self.__get_bulleted_text(paragraph))
return "\n".join(texts)
# Safely get shape_type
try:
shape_type = shape.shape_type
except NotImplementedError:
# If shape_type is not available, try to get text content
if hasattr(shape, 'text'):
return shape.text.strip()
return ""
# Handle table
if shape_type == 19:
tb = shape.table
rows = []
for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(
0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows)
# Handle group shape
if shape_type == 6:
texts = []
for p in sorted(shape.shapes, key=lambda x: (x.top // 10, x.left)):
t = self.__extract(p)
if t:
texts.append(t)
return "\n".join(texts)
return ""
except Exception as e:
logging.error(f"Error processing shape: {str(e)}")
return ""
def __call__(self, fnm, from_page, to_page, callback=None):
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
txts = []
self.total_page = len(ppt.slides)
for i, slide in enumerate(ppt.slides):
if i < from_page:
continue
if i >= to_page:
break
texts = []
for shape in sorted(
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
try:
txt = self.__extract(shape)
if txt:
texts.append(txt)
except Exception as e:
logging.exception(e)
txts.append("\n".join(texts))
return txts

View File

@@ -0,0 +1,48 @@
import re
from .utils import get_text
from app.core.rag.common.token_utils import num_tokens_from_string
class RAGTxtParser:
def __call__(self, fnm, binary=None, chunk_token_num=128, delimiter="\n!?;。;!?"):
txt = get_text(fnm, binary)
return self.parser_txt(txt, chunk_token_num, delimiter)
@classmethod
def parser_txt(cls, txt, chunk_token_num=128, delimiter="\n!?;。;!?"):
if not isinstance(txt, str):
raise TypeError("txt type should be str!")
cks = [""]
tk_nums = [0]
delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
def add_chunk(t):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if tk_nums[-1] > chunk_token_num:
cks.append(t)
tk_nums.append(tnum)
else:
cks[-1] += t
tk_nums[-1] += tnum
dels = []
s = 0
for m in re.finditer(r"`([^`]+)`", delimiter, re.I):
f, t = m.span()
dels.append(m.group(1))
dels.extend(list(delimiter[s: f]))
s = t
if s < len(delimiter):
dels.extend(list(delimiter[s:]))
dels = [re.escape(d) for d in dels if d]
dels = [d for d in dels if d]
dels = "|".join(dels)
secs = re.split(r"(%s)" % dels, txt)
for sec in secs:
if re.match(f"^{dels}$", sec):
continue
add_chunk(sec)
return [[c, ""] for c in cks]

View File

@@ -0,0 +1,16 @@
from app.core.rag.nlp import find_codec
def get_text(fnm: str, binary=None) -> str:
txt = ""
if binary:
encoding = find_codec(binary)
txt = binary.decode(encoding, errors="ignore")
else:
with open(fnm, "r") as f:
while True:
line = f.readline()
if not line:
break
txt += line
return txt

View File

@@ -0,0 +1,75 @@
import io
import sys
import threading
import pdfplumber
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import AscendLayoutRecognizer
from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
if LOCK_KEY_pdfplumber not in sys.modules:
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
def init_in_out(args):
import os
import traceback
from PIL import Image
from app.core.rag.common.file_utils import traversal_files
images = []
outputs = []
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
def pdf_pages(fnm, zoomin=3):
nonlocal outputs, images
with sys.modules[LOCK_KEY_pdfplumber]:
pdf = pdfplumber.open(fnm)
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in enumerate(pdf.pages)]
for i, page in enumerate(images):
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
pdf.close()
def images_and_outputs(fnm):
nonlocal outputs, images
if fnm.split(".")[-1].lower() == "pdf":
pdf_pages(fnm)
return
try:
fp = open(fnm, "rb")
binary = fp.read()
fp.close()
images.append(Image.open(io.BytesIO(binary)).convert("RGB"))
outputs.append(os.path.split(fnm)[-1])
except Exception:
traceback.print_exc()
if os.path.isdir(args.inputs):
for fnm in traversal_files(args.inputs):
images_and_outputs(fnm)
else:
images_and_outputs(args.inputs)
for i in range(len(outputs)):
outputs[i] = os.path.join(args.output_dir, outputs[i])
return images, outputs
__all__ = [
"OCR",
"Recognizer",
"LayoutRecognizer",
"AscendLayoutRecognizer",
"TableStructureRecognizer",
"init_in_out",
]

View File

@@ -0,0 +1,440 @@
import logging
import math
import os
import re
from collections import Counter
from copy import deepcopy
import cv2
import numpy as np
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from . import Recognizer
from .operators import nms
class LayoutRecognizer(Recognizer):
labels = [
"_background_",
"Text",
"Title",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Header",
"Footer",
"Reference",
"Equation",
]
def __init__(self, domain):
try:
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"), local_dir_use_symlinks=False)
super().__init__(self.labels, domain, model_dir)
self.garbage_layouts = ["footer", "header", "reference"]
self.client = None
if os.environ.get("TENSORRT_DLA_SVR"):
from deepdoc.vision.dla_cli import DLAClient
self.client = DLAClient(os.environ["TENSORRT_DLA_SVR"])
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
def __is_garbage(b):
patt = [r"^•+$", "^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}", "\\(cid *: *[0-9]+ *\\)"]
return any([re.search(p, b["text"]) for p in patt])
if self.client:
layouts = self.client.predict(image_list)
else:
layouts = super().__call__(image_list, thr, batch_size)
# save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
assert len(image_list) == len(ocr_res)
# Tag layout type
boxes = []
assert len(image_list) == len(layouts)
garbages = {}
page_layout = []
for pn, lts in enumerate(layouts):
bxs = ocr_res[pn]
lts = [
{
"type": b["type"],
"score": float(b["score"]),
"x0": b["bbox"][0] / scale_factor,
"x1": b["bbox"][2] / scale_factor,
"top": b["bbox"][1] / scale_factor,
"bottom": b["bbox"][-1] / scale_factor,
"page_number": pn,
}
for b in lts
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts
]
lts = self.sort_Y_firstly(lts, np.mean([lt["bottom"] - lt["top"] for lt in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
def findLayout(ty):
nonlocal bxs, lts, self
lts_ = [lt for lt in lts if lt["type"] == ty]
i = 0
while i < len(bxs):
if bxs[i].get("layout_type"):
i += 1
continue
if __is_garbage(bxs[i]):
bxs.pop(i)
continue
ii = self.find_overlapped_with_threshold(bxs[i], lts_, thr=0.4)
if ii is None:
bxs[i]["layout_type"] = ""
i += 1
continue
lts_[ii]["visited"] = True
keep_feats = [
lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
]
if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
bxs.pop(i)
continue
bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"] != "equation" else "figure"
i += 1
for lt in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
findLayout(lt)
# add box to figure layouts which has not text box
for i, lt in enumerate([lt for lt in lts if lt["type"] in ["figure", "equation"]]):
if lt.get("visited"):
continue
lt = deepcopy(lt)
del lt["type"]
lt["text"] = ""
lt["layout_type"] = "figure"
lt["layoutno"] = f"figure-{i}"
bxs.append(lt)
boxes.extend(bxs)
ocr_res = boxes
garbag_set = set()
for k in garbages.keys():
garbages[k] = Counter(garbages[k])
for g, c in garbages[k].items():
if c > 1:
garbag_set.add(g)
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout
def forward(self, image_list, thr=0.7, batch_size=16):
return super().__call__(image_list, thr, batch_size)
class LayoutRecognizer4YOLOv10(LayoutRecognizer):
labels = [
"title",
"Text",
"Reference",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Table caption",
"Equation",
"Figure caption",
]
def __init__(self, domain):
domain = "layout"
super().__init__(domain)
self.auto = False
self.scaleFill = False
self.scaleup = True
self.stride = 32
self.center = True
def preprocess(self, image_list):
inputs = []
new_shape = self.input_shape # height, width
for img in image_list:
shape = img.shape[:2] # current shape [height, width]
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
ww, hh = new_unpad
img = np.array(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).astype(np.float32)
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1] / ww, shape[0] / hh, dw, dh]})
return inputs
def postprocess(self, boxes, inputs, thr):
thr = 0.08
boxes = np.squeeze(boxes)
scores = boxes[:, 4]
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0:
return []
class_ids = boxes[:, -1].astype(int)
boxes = boxes[:, :4]
boxes[:, 0] -= inputs["scale_factor"][2]
boxes[:, 2] -= inputs["scale_factor"][2]
boxes[:, 1] -= inputs["scale_factor"][3]
boxes[:, 3] -= inputs["scale_factor"][3]
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
unique_class_ids = np.unique(class_ids)
indices = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = nms(class_boxes, class_scores, 0.45)
indices.extend(class_indices[class_keep_boxes])
return [{"type": self.label_list[class_ids[i]].lower(), "bbox": [float(t) for t in boxes[i].tolist()], "score": float(scores[i])} for i in indices]
class AscendLayoutRecognizer(Recognizer):
labels = [
"title",
"Text",
"Reference",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Table caption",
"Equation",
"Figure caption",
]
def __init__(self, domain):
from ais_bench.infer.interface import InferSession
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
model_file_path = os.path.join(model_dir, domain + ".om")
if not os.path.exists(model_file_path):
raise ValueError(f"Model file not found: {model_file_path}")
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
self.session = InferSession(device_id=device_id, model_path=model_file_path)
self.input_shape = self.session.get_inputs()[0].shape[2:4] # H,W
self.garbage_layouts = ["footer", "header", "reference"]
def preprocess(self, image_list):
inputs = []
H, W = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
r = min(H / h, W / w)
new_unpad = (int(round(w * r)), int(round(h * r)))
dw, dh = (W - new_unpad[0]) / 2.0, (H - new_unpad[1]) / 2.0
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
img /= 255.0
img = img.transpose(2, 0, 1)[np.newaxis, :, :, :].astype(np.float32)
inputs.append(
{
"image": img,
"scale_factor": [w / new_unpad[0], h / new_unpad[1]],
"pad": [dw, dh],
"orig_shape": [h, w],
}
)
return inputs
def postprocess(self, boxes, inputs, thr=0.25):
arr = np.squeeze(boxes)
if arr.ndim == 1:
arr = arr.reshape(1, -1)
results = []
if arr.shape[1] == 6:
# [x1,y1,x2,y2,score,cls]
m = arr[:, 4] >= thr
arr = arr[m]
if arr.size == 0:
return []
xyxy = arr[:, :4].astype(np.float32)
scores = arr[:, 4].astype(np.float32)
cls_ids = arr[:, 5].astype(np.int32)
if "pad" in inputs:
dw, dh = inputs["pad"]
sx, sy = inputs["scale_factor"]
xyxy[:, [0, 2]] -= dw
xyxy[:, [1, 3]] -= dh
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
else:
# backup
sx, sy = inputs["scale_factor"]
xyxy *= np.array([sx, sy, sx, sy], dtype=np.float32)
keep_indices = []
for c in np.unique(cls_ids):
idx = np.where(cls_ids == c)[0]
k = nms(xyxy[idx], scores[idx], 0.45)
keep_indices.extend(idx[k])
for i in keep_indices:
cid = int(cls_ids[i])
if 0 <= cid < len(self.labels):
results.append({"type": self.labels[cid].lower(), "bbox": [float(t) for t in xyxy[i].tolist()], "score": float(scores[i])})
return results
raise ValueError(f"Unexpected output shape: {arr.shape}")
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
import re
from collections import Counter
assert len(image_list) == len(ocr_res)
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
layouts_all_pages = [] # list of list[{"type","score","bbox":[x1,y1,x2,y2]}]
conf_thr = max(thr, 0.08)
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for bi in range(batch_loop_cnt):
s = bi * batch_size
e = min((bi + 1) * batch_size, len(images))
batch_images = images[s:e]
inputs_list = self.preprocess(batch_images)
logging.debug("preprocess done")
for ins in inputs_list:
feeds = [ins["image"]]
out_list = self.session.infer(feeds=feeds, mode="static")
for out in out_list:
lts = self.postprocess(out, ins, conf_thr)
page_lts = []
for b in lts:
if float(b["score"]) >= 0.4 or b["type"] not in self.garbage_layouts:
x0, y0, x1, y1 = b["bbox"]
page_lts.append(
{
"type": b["type"],
"score": float(b["score"]),
"x0": float(x0) / scale_factor,
"x1": float(x1) / scale_factor,
"top": float(y0) / scale_factor,
"bottom": float(y1) / scale_factor,
"page_number": len(layouts_all_pages),
}
)
layouts_all_pages.append(page_lts)
def _is_garbage_text(box):
patt = [r"^•+$", r"^[0-9]{1,2} / ?[0-9]{1,2}$", r"^[0-9]{1,2} of [0-9]{1,2}$", r"^http://[^ ]{12,}", r"\(cid *: *[0-9]+ *\)"]
return any(re.search(p, box.get("text", "")) for p in patt)
boxes_out = []
page_layout = []
garbages = {}
for pn, lts in enumerate(layouts_all_pages):
if lts:
avg_h = np.mean([lt["bottom"] - lt["top"] for lt in lts])
lts = self.sort_Y_firstly(lts, avg_h / 2 if avg_h > 0 else 0)
bxs = ocr_res[pn]
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
def _tag_layout(ty):
nonlocal bxs, lts
lts_of_ty = [lt for lt in lts if lt["type"] == ty]
i = 0
while i < len(bxs):
if bxs[i].get("layout_type"):
i += 1
continue
if _is_garbage_text(bxs[i]):
bxs.pop(i)
continue
ii = self.find_overlapped_with_threshold(bxs[i], lts_of_ty, thr=0.4)
if ii is None:
bxs[i]["layout_type"] = ""
i += 1
continue
lts_of_ty[ii]["visited"] = True
keep_feats = [
lts_of_ty[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].shape[0] * 0.9 / scale_factor,
lts_of_ty[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].shape[0] * 0.1 / scale_factor,
]
if drop and lts_of_ty[ii]["type"] in self.garbage_layouts and not any(keep_feats):
garbages.setdefault(lts_of_ty[ii]["type"], []).append(bxs[i].get("text", ""))
bxs.pop(i)
continue
bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_of_ty[ii]["type"] if lts_of_ty[ii]["type"] != "equation" else "figure"
i += 1
for ty in ["footer", "header", "reference", "figure caption", "table caption", "title", "table", "text", "figure", "equation"]:
_tag_layout(ty)
figs = [lt for lt in lts if lt["type"] in ["figure", "equation"]]
for i, lt in enumerate(figs):
if lt.get("visited"):
continue
lt = deepcopy(lt)
lt.pop("type", None)
lt["text"] = ""
lt["layout_type"] = "figure"
lt["layoutno"] = f"figure-{i}"
bxs.append(lt)
boxes_out.extend(bxs)
garbag_set = set()
for k, lst in garbages.items():
cnt = Counter(lst)
for g, c in cnt.items():
if c > 1:
garbag_set.add(g)
ocr_res_new = [b for b in boxes_out if b["text"].strip() not in garbag_set]
return ocr_res_new, page_layout

View File

@@ -0,0 +1,737 @@
import gc
import logging
import copy
import time
import os
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.common.misc_utils import pip_install_torch
from app.core.rag.common import settings
from .operators import * # noqa: F403
from . import operators
import math
import numpy as np
import cv2
import onnxruntime as ort
from .postprocess import build_post_process
loaded_models = {}
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(
op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = getattr(operators, op_name)(**param)
ops.append(op)
return ops
def load_model(model_dir, nm, device_id: int | None = None):
model_file_path = os.path.join(model_dir, nm + ".onnx")
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
global loaded_models
loaded_model = loaded_models.get(model_cached_tag)
if loaded_model:
logging.info(f"load_model {model_file_path} reuses cached model")
return loaded_model
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
def cuda_is_available():
try:
pip_install_torch()
import torch
target_id = 0 if device_id is None else device_id
if torch.cuda.is_available() and torch.cuda.device_count() > target_id:
return True
except Exception:
return False
return False
options = ort.SessionOptions()
options.enable_cpu_mem_arena = False
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
options.intra_op_num_threads = 2
options.inter_op_num_threads = 2
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution
run_options = ort.RunOptions()
if cuda_is_available():
gpu_mem_limit_mb = int(os.environ.get("OCR_GPU_MEM_LIMIT_MB", "2048"))
arena_strategy = os.environ.get("OCR_ARENA_EXTEND_STRATEGY", "kNextPowerOfTwo")
provider_device_id = 0 if device_id is None else device_id
cuda_provider_options = {
"device_id": provider_device_id, # Use specific GPU
"gpu_mem_limit": max(gpu_mem_limit_mb, 0) * 1024 * 1024,
"arena_extend_strategy": arena_strategy, # gpu memory allocation strategy
}
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CUDAExecutionProvider'],
provider_options=[cuda_provider_options]
)
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
else:
sess = ort.InferenceSession(
model_file_path,
options=options,
providers=['CPUExecutionProvider'])
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
logging.info(f"load_model {model_file_path} uses CPU")
loaded_model = (sess, run_options)
loaded_models[model_cached_tag] = loaded_model
return loaded_model
class TextRecognizer:
def __init__(self, model_dir, device_id: int | None = None):
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(model_dir, "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
w = self.input_tensor.shape[3:][0]
if isinstance(w, str):
pass
elif w is not None and w > 0:
imgW = w
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
'constant',
constant_values=(255))
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
img = img.astype('float32')
return img
def close(self):
# close session and release manually
logging.info('Close text recognizer.')
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
def __del__(self):
self.close()
class TextDetector:
def __init__(self, model_dir, device_id: int | None = None):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.5, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
self.input_tensor = self.predictor.get_inputs()[0]
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
'DetResizeForTest': {
'image_shape': [img_h, img_w]
}
}
self.preprocess_op = create_operators(pre_process_list)
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def close(self):
logging.info("Close text detector.")
if hasattr(self, "predictor"):
del self.predictor
gc.collect()
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
st = time.time()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
input_dict = {}
input_dict[self.input_tensor.name] = img
for i in range(100000):
try:
outputs = self.predictor.run(None, input_dict, self.run_options)
break
except Exception as e:
if i >= 3:
raise e
time.sleep(5)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes, time.time() - st
def __del__(self):
self.close()
class OCR:
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
try:
model_dir = os.path.join(
get_project_base_directory(),
"res/deepdoc")
# Append muti-gpus task to the list
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
local_dir_use_symlinks=False)
if settings.PARALLEL_DEVICES > 0:
self.text_detector = []
self.text_recognizer = []
for device_id in range(settings.PARALLEL_DEVICES):
self.text_detector.append(TextDetector(model_dir, device_id))
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
else:
self.text_detector = [TextDetector(model_dir)]
self.text_recognizer = [TextRecognizer(model_dir)]
self.drop_score = 0.5
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
# Try original orientation
rec_result = self.text_recognizer[0]([dst_img])
text, score = rec_result[0][0]
best_score = score
best_img = dst_img
# Try clockwise 90° rotation
rotated_cw = np.rot90(dst_img, k=3)
rec_result = self.text_recognizer[0]([rotated_cw])
rotated_cw_text, rotated_cw_score = rec_result[0][0]
if rotated_cw_score > best_score:
best_score = rotated_cw_score
best_img = rotated_cw
# Try counter-clockwise 90° rotation
rotated_ccw = np.rot90(dst_img, k=1)
rec_result = self.text_recognizer[0]([rotated_ccw])
rotated_ccw_text, rotated_ccw_score = rec_result[0][0]
if rotated_ccw_score > best_score:
best_img = rotated_ccw
# Use the best image
dst_img = best_img
return dst_img
def sorted_boxes(self, dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def detect(self, img, device_id: int | None = None):
if device_id is None:
device_id = 0
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if img is None:
return None, None, time_dict
start = time.time()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
return zip(self.sorted_boxes(dt_boxes), [
("", 0) for _ in range(len(dt_boxes))])
def recognize(self, ori_im, box, device_id: int | None = None):
if device_id is None:
device_id = 0
img_crop = self.get_rotate_crop_image(ori_im, box)
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
text, score = rec_res[0]
if score < self.drop_score:
return ""
return text
def recognize_batch(self, img_list, device_id: int | None = None):
if device_id is None:
device_id = 0
rec_res, elapse = self.text_recognizer[device_id](img_list)
texts = []
for i in range(len(rec_res)):
text, score = rec_res[i]
if score < self.drop_score:
text = ""
texts.append(text)
return texts
def __call__(self, img, device_id = 0, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if device_id is None:
device_id = 0
if img is None:
return None, None, time_dict
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector[device_id](img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
img_crop_list = []
dt_boxes = self.sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
time_dict['rec'] = elapse
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
# for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))

View File

@@ -0,0 +1,709 @@
import logging
import sys
import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage:
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert isinstance(img, str) and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert isinstance(img, bytes) and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class StandardizeImag:
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class NormalizeImage:
""" normalize image such as subtract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage:
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class KeepKeys:
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Pad:
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
raise TypeError("Type of target_size is invalid. Now is {}".format(
type(size)))
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
img = data['image']
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
), '(h, w) of target size should be greater than (img_h, img_w)'
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
self.size_div)
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
self.size_div)
img = cv2.copyMakeBorder(
img,
0,
resize_h2 - img_h,
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
value=0)
data['image'] = img
return data
class LinearResize:
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
_im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
_im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class Resize:
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
return data
class DetResizeForTest:
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
h, w, c = im.shape
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
im_pad[:h, :w, :] = im
return im_pad
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except BaseException:
logging.exception("{} {} {}".format(img.shape, resize_w, resize_h))
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest:
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize:
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
class SRResize:
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
_label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize:
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
class GrayImageChannelFormat:
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
inverse: inverse gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
data['image'] = np.abs(img_expanded - 1)
else:
data['image'] = img_expanded
data['src_image'] = img
return data
class Permute:
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride:
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
def nms(bboxes, scores, iou_thresh):
import numpy as np
x1 = bboxes[:, 0]
y1 = bboxes[:, 1]
x2 = bboxes[:, 2]
y2 = bboxes[:, 3]
areas = (y2 - y1) * (x2 - x1)
indices = []
index = scores.argsort()[::-1]
while index.size > 0:
i = index[0]
indices.append(i)
x11 = np.maximum(x1[i], x1[index[1:]])
y11 = np.maximum(y1[i], y1[index[1:]])
x22 = np.minimum(x2[i], x2[index[1:]])
y22 = np.minimum(y2[i], y2[index[1:]])
w = np.maximum(0, x22 - x11 + 1)
h = np.maximum(0, y22 - y11 + 1)
overlaps = w * h
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
idx = np.where(ious <= iou_thresh)[0]
index = index[idx + 1]
return indices

View File

@@ -0,0 +1,354 @@
import copy
import re
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
def build_post_process(config, global_config=None):
support_dict = {'DBPostProcess': DBPostProcess, 'CTCLabelDecode': CTCLabelDecode}
config = copy.deepcopy(config)
module_name = config.pop('name')
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
module_class = support_dict.get(module_name)
if module_class is None:
raise ValueError(
'post process only support {}'.format(list(support_dict)))
return module_class(**config)
class DBPostProcess:
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
box_type='quad',
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
_img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if not isinstance(pred, np.ndarray):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
else:
raise ValueError(
"box_type can only be one of ['quad', 'poly']")
boxes_batch.append({'points': boxes})
return boxes_batch
class BaseRecLabelDecode:
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.reverse = False
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
if 'arabic' in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def pred_reverse(self, pred):
pred_re = []
c_current = ''
for c in pred:
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
if c_current != '':
pred_re.append(c_current)
pred_re.append(c)
c_current = ''
else:
c_current += c
if c_current != '':
pred_re.append(c_current)
return ''.join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if not isinstance(preds, np.ndarray):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character

View File

@@ -0,0 +1,427 @@
import gc
import logging
import os
import math
import numpy as np
import cv2
from functools import cmp_to_key
from app.core.rag.common.file_utils import get_project_base_directory
from .operators import * # noqa: F403
from .operators import preprocess
from . import operators
from .ocr import load_model
class Recognizer:
def __init__(self, label_list, task_name, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = os.path.join(
get_project_base_directory(),
"res/deepdoc")
self.ort_sess, self.run_options = load_model(model_dir, task_name)
self.input_names = [node.name for node in self.ort_sess.get_inputs()]
self.output_names = [node.name for node in self.ort_sess.get_outputs()]
self.input_shape = self.ort_sess.get_inputs()[0].shape[2:4]
self.label_list = label_list
@staticmethod
def sort_Y_firstly(arr, threshold):
def cmp(c1, c2):
diff = c1["top"] - c2["top"]
if abs(diff) < threshold:
diff = c1["x0"] - c2["x0"]
return diff
arr = sorted(arr, key=cmp_to_key(cmp))
return arr
@staticmethod
def sort_X_firstly(arr, threshold):
def cmp(c1, c2):
diff = c1["x0"] - c2["x0"]
if abs(diff) < threshold:
diff = c1["top"] - c2["top"]
return diff
arr = sorted(arr, key=cmp_to_key(cmp))
return arr
@staticmethod
def sort_C_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["x0"], r["top"]))
arr = Recognizer.sort_X_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
# restore the order using th
if "C" not in arr[j] or "C" not in arr[j + 1]:
continue
if arr[j + 1]["C"] < arr[j]["C"] \
or (
arr[j + 1]["C"] == arr[j]["C"]
and arr[j + 1]["top"] < arr[j]["top"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
@staticmethod
def sort_R_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["top"], r["x0"]))
arr = Recognizer.sort_Y_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
if "R" not in arr[j] or "R" not in arr[j + 1]:
continue
if arr[j + 1]["R"] < arr[j]["R"] \
or (
arr[j + 1]["R"] == arr[j]["R"]
and arr[j + 1]["x0"] < arr[j]["x0"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
@staticmethod
def overlapped_area(a, b, ratio=True):
tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
if b["x0"] > x1 or b["x1"] < x0:
return 0
if b["bottom"] < tp or b["top"] > btm:
return 0
x0_ = max(b["x0"], x0)
x1_ = min(b["x1"], x1)
assert x0_ <= x1_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} ==> {}".format(
tp, btm, x0, x1, b)
tp_ = max(b["top"], tp)
btm_ = min(b["bottom"], btm)
assert tp_ <= btm_, "Bbox mismatch! T:{},B:{},X0:{},X1:{} => {}".format(
tp, btm, x0, x1, b)
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
x0 != 0 and btm - tp != 0 else 0
if ov > 0 and ratio:
ov /= (x1 - x0) * (btm - tp)
return ov
@staticmethod
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
def not_overlapped(a, b):
return any([a["x1"] < b["x0"],
a["x0"] > b["x1"],
a["bottom"] < b["top"],
a["top"] > b["bottom"]])
i = 0
while i + 1 < len(layouts):
j = i + 1
while j < min(i + far, len(layouts)) \
and (layouts[i].get("type", "") != layouts[j].get("type", "")
or not_overlapped(layouts[i], layouts[j])):
j += 1
if j >= min(i + far, len(layouts)):
i += 1
continue
if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
i += 1
continue
if layouts[i].get("score") and layouts[j].get("score"):
if layouts[i]["score"] > layouts[j]["score"]:
layouts.pop(j)
else:
layouts.pop(i)
continue
area_i, area_i_1 = 0, 0
for b in boxes:
if not not_overlapped(b, layouts[i]):
area_i += Recognizer.overlapped_area(b, layouts[i], False)
if not not_overlapped(b, layouts[j]):
area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
if area_i > area_i_1:
layouts.pop(j)
else:
layouts.pop(i)
return layouts
def create_inputs(self, imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0],)).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'],)).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'],)).astype('float32')
return inputs
im_shape = np.array([info['im_shape'] for info in im_info], dtype='float32')
scale_factor = np.array([info['scale_factor'] for info in im_info], dtype='float32')
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
@staticmethod
def find_overlapped(box, boxes_sorted_by_y, naive=False):
if not boxes_sorted_by_y:
return
bxs = boxes_sorted_by_y
s, e, ii = 0, len(bxs), 0
while s < e and not naive:
ii = (e + s) // 2
pv = bxs[ii]
if box["bottom"] < pv["top"]:
e = ii
continue
if box["top"] > pv["bottom"]:
s = ii + 1
continue
break
while s < ii:
if box["top"] > bxs[s]["bottom"]:
s += 1
break
while e - 1 > ii:
if box["bottom"] < bxs[e - 1]["top"]:
e -= 1
break
max_overlapped_i, max_overlapped = None, 0
for i in range(s, e):
ov = Recognizer.overlapped_area(bxs[i], box)
if ov <= max_overlapped:
continue
max_overlapped_i = i
max_overlapped = ov
return max_overlapped_i
@staticmethod
def find_horizontally_tightest_fit(box, boxes):
if not boxes:
return
min_dis, min_i = 1000000, None
for i,b in enumerate(boxes):
if box.get("layoutno", "0") != b.get("layoutno", "0"):
continue
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis:
min_i = i
min_dis = dis
return min_i
@staticmethod
def find_overlapped_with_threshold(box, boxes, thr=0.3):
if not boxes:
return
max_overlapped_i, max_overlapped, _max_overlapped = None, thr, 0
s, e = 0, len(boxes)
for i in range(s, e):
ov = Recognizer.overlapped_area(box, boxes[i])
_ov = Recognizer.overlapped_area(boxes[i], box)
if (ov, _ov) < (max_overlapped, _max_overlapped):
continue
max_overlapped_i = i
max_overlapped = ov
_max_overlapped = _ov
return max_overlapped_i
def preprocess(self, image_list):
inputs = []
if "scale_factor" in self.input_names:
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(getattr(operators, op_type)(**new_op_info))
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'),
"scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
else:
hh, ww = self.input_shape
for img in image_list:
h, w = img.shape[:2]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(np.array(img).astype('float32'), (ww, hh))
# Scale input pixel values to 0 to 1
img /= 255.0
img = img.transpose(2, 0, 1)
img = img[np.newaxis, :, :, :].astype(np.float32)
inputs.append({self.input_names[0]: img, "scale_factor": [w/ww, h/hh]})
return inputs
def postprocess(self, boxes, inputs, thr):
if "scale_factor" in self.input_names:
bb = []
for b in boxes:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
return bb
def xywh2xyxy(x):
# [x, y, w, h] to [x1, y1, x2, y2]
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2
y[:, 1] = x[:, 1] - x[:, 3] / 2
y[:, 2] = x[:, 0] + x[:, 2] / 2
y[:, 3] = x[:, 1] + x[:, 3] / 2
return y
def compute_iou(box, boxes):
# Compute xmin, ymin, xmax, ymax for both boxes
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])
# Compute intersection area
intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
# Compute union area
box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area
# Compute IoU
iou = intersection_area / union_area
return iou
def iou_filter(boxes, scores, iou_threshold):
sorted_indices = np.argsort(scores)[::-1]
keep_boxes = []
while sorted_indices.size > 0:
# Pick the last box
box_id = sorted_indices[0]
keep_boxes.append(box_id)
# Compute IoU of the picked box with the rest
ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
# Remove boxes with IoU over the threshold
keep_indices = np.where(ious < iou_threshold)[0]
# print(keep_indices.shape, sorted_indices.shape)
sorted_indices = sorted_indices[keep_indices + 1]
return keep_boxes
boxes = np.squeeze(boxes).T
# Filter out object confidence scores below threshold
scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0:
return []
# Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1)
boxes = boxes[:, :4]
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0], inputs["scale_factor"][1]])
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
boxes = xywh2xyxy(boxes)
unique_class_ids = np.unique(class_ids)
indices = []
for class_id in unique_class_ids:
class_indices = np.where(class_ids == class_id)[0]
class_boxes = boxes[class_indices, :]
class_scores = scores[class_indices]
class_keep_boxes = iou_filter(class_boxes, class_scores, 0.2)
indices.extend(class_indices[class_keep_boxes])
return [{
"type": self.label_list[class_ids[i]].lower(),
"bbox": [float(t) for t in boxes[i].tolist()],
"score": float(scores[i])
} for i in indices]
def close(self):
logging.info("Close recognizer.")
if hasattr(self, "ort_sess"):
del self.ort_sess
gc.collect()
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
images = []
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
images.append(np.array(image_list[i]))
else:
images.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(images))
batch_image_list = images[start_index:end_index]
inputs = self.preprocess(batch_image_list)
logging.debug("preprocess")
for ins in inputs:
bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names}, self.run_options)[0], ins, thr)
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res
def __del__(self):
self.close()

View File

@@ -0,0 +1,71 @@
import logging
import os
import PIL
from PIL import ImageDraw
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for idx, im in enumerate(image_list):
im = draw_box(im, results[idx], labels, threshold=threshold)
out_path = os.path.join(output_dir, f"{idx}.jpg")
im.save(out_path, quality=95)
logging.debug("save result to: " + out_path)
def draw_box(im, result, labels, threshold=0.5):
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
color_list = get_color_map_list(len(labels))
clsid2color = {n.lower():color_list[i] for i,n in enumerate(labels)}
result = [r for r in result if r["score"] >= threshold]
for dt in result:
color = tuple(clsid2color[dt["type"]])
xmin, ymin, xmax, ymax = dt["bbox"]
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
# draw label
text = "{} {:.4f}".format(dt["type"], dt["score"])
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def imagedraw_textsize_c(draw, text):
if int(PIL.__version__.split('.')[0]) < 10:
tw, th = draw.textsize(text)
else:
left, top, right, bottom = draw.textbbox((0, 0), text)
tw, th = right - left, bottom - top
return tw, th

View File

@@ -0,0 +1,77 @@
import os
import sys
sys.path.insert(
0,
os.path.abspath(
os.path.join(
os.path.dirname(
os.path.abspath(__file__)),
'../../')))
from .seeit import draw_box
from . import OCR, init_in_out
import argparse
import numpy as np
import trio
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
def main(args):
import torch.cuda
cuda_devices = torch.cuda.device_count()
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
ocr = OCR()
images, outputs = init_in_out(args)
def __ocr(i, id, img):
print("Task {} start".format(i))
bxs = ocr(np.array(img), id)
bxs = [(line[0], line[1][0]) for line in bxs]
bxs = [{
"text": t,
"bbox": [b[0][0], b[0][1], b[1][0], b[-1][1]],
"type": "ocr",
"score": 1} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]]
img = draw_box(images[i], bxs, ["ocr"], 1.)
img.save(outputs[i], quality=95)
with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
f.write("\n".join([o["text"] for o in bxs]))
print("Task {} done".format(i))
async def __ocr_thread(i, id, img, limiter = None):
if limiter:
async with limiter:
print("Task {} use device {}".format(i, id))
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
else:
__ocr(i, id, img)
async def __ocr_launcher():
if cuda_devices > 1:
async with trio.open_nursery() as nursery:
for i, img in enumerate(images):
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
await trio.sleep(0.1)
else:
for i, img in enumerate(images):
await __ocr_thread(i, 0, img)
trio.run(__ocr_launcher)
print("OCR tasks are all done")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './ocr_outputs'",
default="./ocr_outputs")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,170 @@
import logging
import os
import sys
sys.path.insert(
0,
os.path.abspath(
os.path.join(
os.path.dirname(
os.path.abspath(__file__)),
'../../')))
from .seeit import draw_box
from . import LayoutRecognizer, TableStructureRecognizer, OCR, init_in_out
import argparse
import re
import numpy as np
def main(args):
images, outputs = init_in_out(args)
if args.mode.lower() == "layout":
detr = LayoutRecognizer("layout")
layouts = detr.forward(images, thr=float(args.threshold))
if args.mode.lower() == "tsr":
detr = TableStructureRecognizer()
ocr = OCR()
layouts = detr(images, thr=float(args.threshold))
for i, lyt in enumerate(layouts):
if args.mode.lower() == "tsr":
#lyt = [t for t in lyt if t["type"] == "table column"]
html = get_table_html(images[i], lyt, ocr)
with open(outputs[i] + ".html", "w+", encoding='utf-8') as f:
f.write(html)
lyt = [{
"type": t["label"],
"bbox": [t["x0"], t["top"], t["x1"], t["bottom"]],
"score": t["score"]
} for t in lyt]
img = draw_box(images[i], lyt, detr.labels, float(args.threshold))
img.save(outputs[i], quality=95)
logging.info("save result to: " + outputs[i])
def get_table_html(img, tb_cpns, ocr):
boxes = ocr(np.array(img))
boxes = LayoutRecognizer.sort_Y_firstly(
[{"x0": b[0][0], "x1": b[1][0],
"top": b[0][1], "text": t[0],
"bottom": b[-1][1],
"layout_type": "table",
"page_number": 0} for b, t in boxes if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
np.mean([b[-1][1] - b[0][1] for b, _ in boxes]) / 3
)
def gather(kwd, fzy=10, ption=0.6):
nonlocal boxes
eles = LayoutRecognizer.sort_Y_firstly(
[r for r in tb_cpns if re.match(kwd, r["label"])], fzy)
eles = LayoutRecognizer.layouts_cleanup(boxes, eles, 5, ption)
return LayoutRecognizer.sort_Y_firstly(eles, 0)
headers = gather(r".*header$")
rows = gather(r".* (row|header)")
spans = gather(r".*spanning")
clmns = sorted([r for r in tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: x["x0"])
clmns = LayoutRecognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
for b in boxes:
ii = LayoutRecognizer.find_overlapped_with_threshold(b, rows, thr=0.3)
if ii is not None:
b["R"] = ii
b["R_top"] = rows[ii]["top"]
b["R_bott"] = rows[ii]["bottom"]
ii = LayoutRecognizer.find_overlapped_with_threshold(b, headers, thr=0.3)
if ii is not None:
b["H_top"] = headers[ii]["top"]
b["H_bott"] = headers[ii]["bottom"]
b["H_left"] = headers[ii]["x0"]
b["H_right"] = headers[ii]["x1"]
b["H"] = ii
ii = LayoutRecognizer.find_horizontally_tightest_fit(b, clmns)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
b["C_right"] = clmns[ii]["x1"]
ii = LayoutRecognizer.find_overlapped_with_threshold(b, spans, thr=0.3)
if ii is not None:
b["H_top"] = spans[ii]["top"]
b["H_bott"] = spans[ii]["bottom"]
b["H_left"] = spans[ii]["x0"]
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
html = """
<html>
<head>
<style>
._table_1nkzy_11 {
margin: auto;
width: 70%%;
padding: 10px;
}
._table_1nkzy_11 p {
margin-bottom: 50px;
border: 1px solid #e1e1e1;
}
caption {
color: #6ac1ca;
font-size: 20px;
height: 50px;
line-height: 50px;
font-weight: 600;
margin-bottom: 10px;
}
._table_1nkzy_11 table {
width: 100%%;
border-collapse: collapse;
}
th {
color: #fff;
background-color: #6ac1ca;
}
td:hover {
background: #c1e8e8;
}
tr:nth-child(even) {
background-color: #f2f2f2;
}
._table_1nkzy_11 th,
._table_1nkzy_11 td {
text-align: center;
border: 1px solid #ddd;
padding: 8px;
}
</style>
</head>
<body>
%s
</body>
</html>
""" % TableStructureRecognizer.construct_table(boxes, html=True)
return html
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--inputs',
help="Directory where to store images or PDFs, or a file path to a single image or PDF",
required=True)
parser.add_argument('--output_dir', help="Directory where to store the output images. Default: './layouts_outputs'",
default="./layouts_outputs")
parser.add_argument(
'--threshold',
help="A threshold to filter out detections. Default: 0.5",
default=0.5)
parser.add_argument('--mode', help="Task mode: layout recognition or table structure recognition", choices=["layout", "tsr"],
default="layout")
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,597 @@
import logging
import os
import re
from collections import Counter
import numpy as np
from huggingface_hub import snapshot_download
from app.core.rag.common.file_utils import get_project_base_directory
from app.core.rag.nlp import rag_tokenizer
from .recognizer import Recognizer
class TableStructureRecognizer(Recognizer):
labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
def __init__(self):
try:
super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "res/deepdoc"))
except Exception:
super().__init__(
self.labels,
"tsr",
snapshot_download(
repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "res/deepdoc"),
local_dir_use_symlinks=False,
),
)
def __call__(self, images, thr=0.2):
table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower()
if table_structure_recognizer_type not in ["onnx", "ascend"]:
raise RuntimeError("Unsupported table structure recognizer type.")
if table_structure_recognizer_type == "onnx":
logging.debug("Using Onnx table structure recognizer")
tbls = super().__call__(images, thr)
else: # ascend
logging.debug("Using Ascend table structure recognizer")
tbls = self._run_ascend_tsr(images, thr)
res = []
# align left&right for rows, align top&bottom for columns
for tbl in tbls:
lts = [
{
"label": b["type"],
"score": b["score"],
"x0": b["bbox"][0],
"x1": b["bbox"][2],
"top": b["bbox"][1],
"bottom": b["bbox"][-1],
}
for b in tbl
]
if not lts:
continue
left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
if not left:
continue
left = np.mean(left) if len(left) > 4 else np.min(left)
right = np.mean(right) if len(right) > 4 else np.max(right)
for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left:
b["x0"] = left
if b["x1"] < right:
b["x1"] = right
top = [b["top"] for b in lts if b["label"] == "table column"]
bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
if not top:
res.append(lts)
continue
top = np.median(top) if len(top) > 4 else np.min(top)
bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
for b in lts:
if b["label"] == "table column":
if b["top"] > top:
b["top"] = top
if b["bottom"] < bottom:
b["bottom"] = bottom
res.append(lts)
return res
@staticmethod
def is_caption(bx):
patt = [r"[图表]+[ 0-9:]{2,}"]
if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0:
return True
return False
@staticmethod
def blockType(b):
patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg"),
]
for p, n in patt:
if re.search(p, b["text"].strip()):
return n
tks = [t for t in rag_tokenizer.tokenize(b["text"]).split() if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and rag_tokenizer.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
@staticmethod
def construct_table(boxes, is_english=False, html=True, **kwargs):
cap = ""
i = 0
while i < len(boxes):
if TableStructureRecognizer.is_caption(boxes[i]):
if is_english:
cap + " "
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1
i += 1
if not boxes:
return []
for b in boxes:
b["btype"] = TableStructureRecognizer.blockType(b)
max_type = Counter([b["btype"] for b in boxes]).items()
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
logging.debug("MAXTYPE: " + max_type)
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
rowh = np.min(rowh) if rowh else 0
boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
# for b in boxes:print(b)
boxes[0]["rn"] = 0
rows = [[boxes[0]]]
btm = boxes[0]["bottom"]
for b in boxes[1:]:
b["rn"] = len(rows) - 1
lst_r = rows[-1]
if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")): # new row
btm = b["bottom"]
b["rn"] += 1
rows.append([b])
continue
btm = (btm + b["bottom"]) / 2.0
rows[-1].append(b)
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
colwm = np.min(colwm) if colwm else 0
crosspage = len(set([b["page_number"] for b in boxes])) > 1
if crosspage:
boxes = Recognizer.sort_X_firstly(boxes, colwm / 2)
else:
boxes = Recognizer.sort_C_firstly(boxes, colwm / 2)
boxes[0]["cn"] = 0
cols = [[boxes[0]]]
right = boxes[0]["x1"]
for b in boxes[1:]:
b["cn"] = len(cols) - 1
lst_c = cols[-1]
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or (
b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")
): # new col
right = b["x1"]
b["cn"] += 1
cols.append([b])
continue
right = (right + b["x1"]) / 2.0
cols[-1].append(b)
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
for b in boxes:
tbl[b["rn"]][b["cn"]].append(b)
if len(rows) >= 4:
# remove single in column
j = 0
while j < len(tbl[0]):
e, ii = 0, 0
for i in range(len(tbl)):
if tbl[i][j]:
e += 1
ii = i
if e > 1:
break
if e > 1:
j += 1
continue
f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
if f and ff:
j += 1
continue
bx = tbl[ii][j][0]
logging.debug("Relocate column single: " + bx["text"])
# j column only has one value
left, right = 100000, 100000
if j > 0 and not f:
for i in range(len(tbl)):
if tbl[i][j - 1]:
left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
if j + 1 < len(tbl[0]) and not ff:
for i in range(len(tbl)):
if tbl[i][j + 1]:
right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
assert left < 100000 or right < 100000
if left < right:
for jj in range(j, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j - 1]:
tbl[ii][j - 1].extend(tbl[ii][j])
else:
tbl[ii][j - 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
else:
for jj in range(j + 1, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j + 1]:
tbl[ii][j + 1].extend(tbl[ii][j])
else:
tbl[ii][j + 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
cols.pop(j)
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0]))
if len(cols) >= 4:
# remove single in row
i = 0
while i < len(tbl):
e, jj = 0, 0
for j in range(len(tbl[i])):
if tbl[i][j]:
e += 1
jj = j
if e > 1:
break
if e > 1:
i += 1
continue
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
if f and ff:
i += 1
continue
bx = tbl[i][jj][0]
logging.debug("Relocate row single: " + bx["text"])
# i row only has one value
up, down = 100000, 100000
if i > 0 and not f:
for j in range(len(tbl[i - 1])):
if tbl[i - 1][j]:
up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
if i + 1 < len(tbl) and not ff:
for j in range(len(tbl[i + 1])):
if tbl[i + 1][j]:
down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
assert up < 100000 or down < 100000
if up < down:
for ii in range(i, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i - 1][jj]:
tbl[i - 1][jj].extend(tbl[i][jj])
else:
tbl[i - 1][jj] = tbl[i][jj]
tbl.pop(i)
else:
for ii in range(i + 1, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i + 1][jj]:
tbl[i + 1][jj].extend(tbl[i][jj])
else:
tbl[i + 1][jj] = tbl[i][jj]
tbl.pop(i)
rows.pop(i)
# which rows are headers
hdset = set([])
for i in range(len(tbl)):
cnt, h = 0, 0
for j, arr in enumerate(tbl[i]):
if not arr:
continue
cnt += 1
if max_type == "Nu" and arr[0]["btype"] == "Nu":
continue
if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
h += 1
if h / cnt > 0.5:
hdset.add(i)
if html:
return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))
return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english)
@staticmethod
def __html_table(cap, hdset, tbl):
# constrcut HTML
html = "<table>"
if cap:
html += f"<caption>{cap}</caption>"
for i in range(len(tbl)):
row = "<tr>"
txts = []
for j, arr in enumerate(tbl[i]):
if arr is None:
continue
if not arr:
row += "<td></td>" if i not in hdset else "<th></th>"
continue
txt = ""
if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
txts.append(txt)
sp = ""
if arr[0].get("colspan"):
sp = "colspan={}".format(arr[0]["colspan"])
if arr[0].get("rowspan"):
sp += " rowspan={}".format(arr[0]["rowspan"])
if i in hdset:
row += f"<th {sp} >" + txt + "</th>"
else:
row += f"<td {sp} >" + txt + "</td>"
if i in hdset:
if all([t in hdset for t in txts]):
continue
for t in txts:
hdset.add(t)
if row != "<tr>":
row += "</tr>"
else:
row = ""
html += "\n" + row
html += "\n</table>"
return html
@staticmethod
def __desc_table(cap, hdr_rowno, tbl, is_english):
# get text of every colomn in header row to become header text
clmno = len(tbl[0])
rowno = len(tbl)
headers = {}
hdrset = set()
lst_hdr = []
de = "" if not is_english else " for "
for r in sorted(list(hdr_rowno)):
headers[r] = ["" for _ in range(clmno)]
for i in range(clmno):
if not tbl[r][i]:
continue
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt
hdrset.add(txt)
if all([not t for t in headers[r]]):
del headers[r]
hdr_rowno.remove(r)
continue
for j in range(clmno):
if headers[r][j]:
continue
if j >= len(lst_hdr):
break
headers[r][j] = lst_hdr[j]
lst_hdr = headers[r]
for i in range(rowno):
if i not in hdr_rowno:
continue
for j in range(i + 1, rowno):
if j not in hdr_rowno:
break
for k in range(clmno):
if not headers[j - 1][k]:
continue
if headers[j][k].find(headers[j - 1][k]) >= 0:
continue
if len(headers[j][k]) > len(headers[j - 1][k]):
headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k]
else:
headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k]
logging.debug(f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
row_txt = []
for i in range(rowno):
if i in hdr_rowno:
continue
rtxt = []
def append(delimer):
nonlocal rtxt, row_txt
rtxt = delimer.join(rtxt)
if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
row_txt[-1] += "\n" + rtxt
else:
row_txt.append(rtxt)
r = 0
if len(headers.items()):
_arr = [(i - r, r) for r, _ in headers.items() if r < i]
if _arr:
_, r = min(_arr, key=lambda x: x[0])
if r not in headers and clmno <= 2:
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if txt:
rtxt.append(txt)
if rtxt:
append("")
continue
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if not txt:
continue
ctt = headers[r][j] if r in headers else ""
if ctt:
ctt += ""
ctt += txt
if ctt:
rtxt.append(ctt)
if rtxt:
row_txt.append("; ".join(rtxt))
if cap:
if is_english:
from_ = " in "
else:
from_ = "来自"
row_txt = [t + f"\t——{from_}{cap}" for t in row_txt]
return row_txt
@staticmethod
def __cal_spans(boxes, rows, cols, tbl, html=True):
# caculate span
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols]
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols]
rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows]
rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows]
for b in boxes:
if "SP" not in b:
continue
b["colspan"] = [b["cn"]]
b["rowspan"] = [b["rn"]]
# col span
for j in range(0, len(clft)):
if j == b["cn"]:
continue
if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
continue
if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
continue
b["colspan"].append(j)
# row span
for j in range(0, len(rtop)):
if j == b["rn"]:
continue
if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
continue
if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
continue
b["rowspan"].append(j)
def join(arr):
if not arr:
return ""
return "".join([t["text"] for t in arr])
# rm the spaning cells
for i in range(len(tbl)):
for j, arr in enumerate(tbl[i]):
if not arr:
continue
if all(["rowspan" not in a and "colspan" not in a for a in arr]):
continue
rowspan, colspan = [], []
for a in arr:
if isinstance(a.get("rowspan", 0), list):
rowspan.extend(a["rowspan"])
if isinstance(a.get("colspan", 0), list):
colspan.extend(a["colspan"])
rowspan, colspan = set(rowspan), set(colspan)
if len(rowspan) < 2 and len(colspan) < 2:
for a in arr:
if "rowspan" in a:
del a["rowspan"]
if "colspan" in a:
del a["colspan"]
continue
rowspan, colspan = sorted(rowspan), sorted(colspan)
rowspan = list(range(rowspan[0], rowspan[-1] + 1))
colspan = list(range(colspan[0], colspan[-1] + 1))
assert i in rowspan, rowspan
assert j in colspan, colspan
arr = []
for r in rowspan:
for c in colspan:
arr_txt = join(arr)
if tbl[r][c] and join(tbl[r][c]) != arr_txt:
arr.extend(tbl[r][c])
tbl[r][c] = None if html else arr
for a in arr:
if len(rowspan) > 1:
a["rowspan"] = len(rowspan)
elif "rowspan" in a:
del a["rowspan"]
if len(colspan) > 1:
a["colspan"] = len(colspan)
elif "colspan" in a:
del a["colspan"]
tbl[rowspan[0]][colspan[0]] = arr
return tbl
def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
import math
from ais_bench.infer.interface import InferSession
model_dir = os.path.join(get_project_base_directory(), "res/deepdoc")
model_file_path = os.path.join(model_dir, "tsr.om")
if not os.path.exists(model_file_path):
raise ValueError(f"Model file not found: {model_file_path}")
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
session = InferSession(device_id=device_id, model_path=model_file_path)
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
results = []
conf_thr = max(thr, 0.08)
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
for bi in range(batch_loop_cnt):
s = bi * batch_size
e = min((bi + 1) * batch_size, len(images))
batch_images = images[s:e]
inputs_list = self.preprocess(batch_images)
for ins in inputs_list:
feeds = []
if "image" in ins:
feeds.append(ins["image"])
else:
feeds.append(ins[self.input_names[0]])
output_list = session.infer(feeds=feeds, mode="static")
bb = self.postprocess(output_list, ins, conf_thr)
results.append(bb)
return results

View File

View File

@@ -0,0 +1,19 @@
import xxhash
from app.aioRedis import aio_redis_set, aio_redis_get
def get_llm_cache(llmnm, txt, history, genconf):
hasher = xxhash.xxh64()
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
bin = aio_redis_get(k)
if not bin:
return None
return bin
def set_llm_cache(llmnm, txt, v, history, genconf):
hasher = xxhash.xxh64()
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
k = hasher.hexdigest()
aio_redis_set(k, v.encode("utf-8"), 24 * 3600)

View File

View File

@@ -0,0 +1,670 @@
import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
from typing import Any, Protocol
from urllib.parse import urljoin
import json_repair
import openai
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from app.core.rag.nlp import is_chinese, is_english
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
# Error message constants
class LLMErrorCode(StrEnum):
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class ReActMode(StrEnum):
FUNCTION_CALL = "function_call"
REACT = "react"
ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
class ToolCallSession(Protocol):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
class Base(ABC):
def __init__(self, key, model_name, base_url, **kwargs):
timeout = int(os.environ.get("LLM_TIMEOUT_SECONDS", 600))
self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
self.model_name = model_name
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
keywords_mapping = [
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
(["max rounds"], LLMErrorCode.ERROR_MODEL),
]
for words, code in keywords_mapping:
if re.search("({})".format("|".join(words)), error_str):
return code
return LLMErrorCode.ERROR_GENERIC
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
allowed_conf = {
"temperature",
"max_completion_tokens",
"top_p",
"stream",
"stream_options",
"stop",
"n",
"presence_penalty",
"frequency_penalty",
"functions",
"function_call",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"logprobs",
"top_logprobs",
"extra_headers"
}
gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf}
return gen_conf
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")
final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol
if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"
return final_ans.strip(), tol_token
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
return ans + LENGTH_NOTIFICATION_EN
@property
def _retryable_errors(self) -> set[str]:
return {
LLMErrorCode.ERROR_RATE_LIMIT,
LLMErrorCode.ERROR_SERVER,
}
def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors
def _exceptions(self, e, attempt) -> str | None:
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
{
"role": "assistant",
"tool_calls": [
{
"index": tool_call.index,
"id": tool_call.id,
"function": {
"name": tool_call.function.name,
"arguments": tool_call.function.arguments,
},
"type": "function",
},
],
}
)
try:
if isinstance(tool_res, dict):
tool_res = json.dumps(tool_res, ensure_ascii=False)
finally:
hist.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_res)})
return hist
def bind_tools(self, toolcall_session, tools):
if not (toolcall_session and tools):
return
self.is_tools = True
self.toolcall_session = toolcall_session
self.tools = tools
def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"
ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."
def _wrap_toolcall_message(self, stream):
final_tool_calls = {}
for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
final_tool_calls[index] = tool_call
final_tool_calls[index].function.arguments += tool_call.function.arguments
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index
if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")
if answer:
yield total_tokens
return
for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""
def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total
# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens
# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)
if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192
return ctx_size
class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
class XinferenceChat(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)
class HuggingFaceChat(Base):
_FACTORY_NAME = "HuggingFace"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
class ModelScopeChat(Base):
_FACTORY_NAME = "ModelScope"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name.split("___")[0], base_url, **kwargs)
class AzureChat(Base):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, base_url, **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
super().__init__(key, model_name, base_url, **kwargs)
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=base_url, api_version=api_version)
self.model_name = model_name
@property
def _retryable_errors(self) -> set[str]:
return {
LLMErrorCode.ERROR_RATE_LIMIT,
LLMErrorCode.ERROR_SERVER,
LLMErrorCode.ERROR_QUOTA,
}
class BaiChuanChat(Base):
_FACTORY_NAME = "BaiChuan"
def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.baichuan-ai.com/v1"
super().__init__(key, model_name, base_url, **kwargs)
@staticmethod
def _format_params(params):
return {
"temperature": params.get("temperature", 0.3),
"top_p": params.get("top_p", 0.85),
}
def _clean_conf(self, gen_conf):
return {
"temperature": gen_conf.get("temperature", 0.3),
"top_p": gen_conf.get("top_p", 0.85),
}
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
**gen_conf,
)
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
ans = ""
total_tokens = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
stream=True,
**self._format_params(gen_conf),
)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans = resp.choices[0].delta.content
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
if resp.choices[0].finish_reason == "length":
if is_chinese([ans]):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield total_tokens
class LocalAIChat(Base):
_FACTORY_NAME = "LocalAI"
def __init__(self, key, model_name, base_url=None, **kwargs):
super().__init__(key, model_name, base_url=base_url, **kwargs)
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key="empty", base_url=base_url)
self.model_name = model_name.split("___")[0]
class VolcEngineChat(Base):
_FACTORY_NAME = "VolcEngine"
def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3", **kwargs):
"""
Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
model_name is for display only
"""
base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
ark_api_key = json.loads(key).get("ark_api_key", "")
model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
super().__init__(ark_api_key, model_name, base_url, **kwargs)
class OpenAI_APIChat(Base):
_FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
def __init__(self, key, model_name, base_url, **kwargs):
if not base_url:
raise ValueError("url cannot be None")
model_name = model_name.split("___")[0]
super().__init__(key, model_name, base_url, **kwargs)
class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)

View File

@@ -0,0 +1,470 @@
import base64
import json
import os
import tempfile
import logging
from abc import ABC
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from urllib.parse import urljoin
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from app.core.rag.nlp import is_english
from app.core.rag.prompts.generator import vision_llm_describe_prompt
from app.core.rag.common.token_utils import num_tokens_from_string, total_token_count_from_response
class Base(ABC):
def __init__(self, **kwargs):
# Configure retry parameters
self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
self.max_rounds = kwargs.get("max_rounds", 5)
self.is_tools = False
self.tools = []
self.toolcall_sessions = {}
self.extra_body = None
def describe(self, image):
raise NotImplementedError("Please implement encode method!")
def describe_with_prompt(self, image, prompt=None):
raise NotImplementedError("Please implement encode method!")
def _form_history(self, system, history, images=None):
hist = []
if system:
hist.append({"role": "system", "content": system})
for h in history:
if images and h["role"] == "user":
h["content"] = self._image_prompt(h["content"], images)
images = []
hist.append(h)
return hist
def _image_prompt(self, text, images):
if not images:
return text
if isinstance(images, str) or "bytes" in type(images).__name__:
images = [images]
pmpt = [{"type": "text", "text": text}]
for img in images:
pmpt.append({
"type": "image_url",
"image_url": {
"url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
}
})
return pmpt
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
extra_body=self.extra_body,
)
return response.choices[0].message.content.strip(), response.usage.total_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
tk_count = 0
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
extra_body=self.extra_body,
)
for resp in response:
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans = delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
if resp.choices[0].finish_reason == "stop":
tk_count += resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
@staticmethod
def image2base64_rawvalue(self, image):
# Return a base64 string without data URL header
if isinstance(image, bytes):
b64 = base64.b64encode(image).decode("utf-8")
return b64
if isinstance(image, BytesIO):
data = image.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
return b64
with BytesIO() as buffered:
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
data = buffered.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
return b64
@staticmethod
def image2base64(image):
# Return a data URL with the correct MIME to avoid provider mismatches
if isinstance(image, bytes):
# Best-effort magic number sniffing
mime = "image/png"
if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(image).decode("utf-8")
return f"data:{mime};base64,{b64}"
if isinstance(image, BytesIO):
data = image.getvalue()
mime = "image/png"
if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
mime = "image/jpeg"
b64 = base64.b64encode(data).decode("utf-8")
return f"data:{mime};base64,{b64}"
with BytesIO() as buffered:
fmt = "jpeg"
try:
image.save(buffered, format="JPEG")
except Exception:
# reset buffer before saving PNG
buffered.seek(0)
buffered.truncate()
image.save(buffered, format="PNG")
fmt = "png"
data = buffered.getvalue()
b64 = base64.b64encode(data).decode("utf-8")
mime = f"image/{fmt}"
return f"data:{mime};base64,{b64}"
def prompt(self, b64):
return [
{
"role": "user",
"content": self._image_prompt(
"请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
if self.lang.lower() == "chinese"
else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
b64
)
}
]
def vision_llm_prompt(self, b64, prompt=None):
return [
{
"role": "user",
"content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
}
]
class GptV4(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
super().__init__(**kwargs)
def describe(self, image):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model=self.model_name,
messages=self.vision_llm_prompt(b64, prompt),
extra_body=self.extra_body,
)
return res.choices[0].message.content.strip(),total_token_count_from_response(res)
class AzureGptV4(GptV4):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
api_key = json.loads(key).get("api_key", "")
api_version = json.loads(key).get("api_version", "2024-02-01")
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class QWenCV(GptV4):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
if not base_url:
base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
def chat(self, system, history, gen_conf, images=None, video_bytes=None, filename="", **kwargs):
if video_bytes:
try:
summary, summary_num_tokens = self._process_video(video_bytes, filename)
return summary, summary_num_tokens
except Exception as e:
return "**ERROR**: " + str(e), 0
return "**ERROR**: Method chat not supported yet.", 0
def _process_video(self, video_bytes, filename):
from dashscope import MultiModalConversation
video_suffix = Path(filename).suffix or ".mp4"
with tempfile.NamedTemporaryFile(delete=False, suffix=video_suffix) as tmp:
tmp.write(video_bytes)
tmp_path = tmp.name
video_path = f"file://{tmp_path}"
messages = [
{
"role": "user",
"content": [
{
"video": video_path,
"fps": 2,
},
{
"text": "Please summarize this video in proper sentences.",
},
],
}
]
def call_api():
response = MultiModalConversation.call(
api_key=self.api_key,
model=self.model_name,
messages=messages,
)
summary = response["output"]["choices"][0]["message"].content[0]["text"]
return summary, num_tokens_from_string(summary)
try:
return call_api()
except Exception as e1:
import dashscope
dashscope.base_http_api_url = "https://dashscope-intl.aliyuncs.com/api/v1"
try:
return call_api()
except Exception as e2:
raise RuntimeError(f"Both default and intl endpoint failed.\nFirst error: {e1}\nSecond error: {e2}")
class XinferenceCV(GptV4):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class GPUStackCV(GptV4):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
Base.__init__(self, **kwargs)
class OllamaCV(Base):
_FACTORY_NAME = "Ollama"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
from ollama import Client
self.client = Client(host=kwargs["base_url"])
self.model_name = model_name
self.lang = lang
self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
Base.__init__(self, **kwargs)
def _clean_img(self, img):
if not isinstance(img, str):
return img
#remove the header like "data/*;base64,"
if img.startswith("data:") and ";base64," in img:
img = img.split(";base64,")[1]
return img
def _clean_conf(self, gen_conf):
options = {}
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
return options
def _form_history(self, system, history, images=None):
hist = deepcopy(history)
if system and hist[0]["role"] == "user":
hist.insert(0, {"role": "system", "content": system})
if not images:
return hist
temp_images = []
for img in images:
temp_images.append(self._clean_img(img))
for his in hist:
if his["role"] == "user":
his["images"] = temp_images
break
return hist
def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=prompt[0]["content"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=vision_prompt[0]["content"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, images=None, **kwargs):
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
ans = response["message"]["content"].strip()
return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
ans = ""
try:
response = self.client.chat(
model=self.model_name,
messages=self._form_history(system, history, images),
stream=True,
options=self._clean_conf(gen_conf),
keep_alive=self.keep_alive
)
for resp in response:
if resp["done"]:
yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
ans = resp["message"]["content"]
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
yield 0
if __name__ == "__main__":
# import sys
# chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
# # 准备配置vision_model信息
# azure_config = {
# "api_key": "xxxxx",
# "api_version": "2024-02-01"
# }
# # 转换为 JSON 字符串,因为类中使用 json.loads(key) 解析
# key = json.dumps(azure_config)
# # 初始化 AzureGptV4
# vision_model = AzureGptV4(
# key=key, # JSON 字符串形式的配置
# model_name="gpt-4o",
# lang="Chinese", # 默认使用中文
# base_url="https://fosun-openai-east-us-001.openai.azure.com/" # Azure OpenAI 端点
# )
# try:
# # 测试图像描述功能
# image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/aippt.cn.png"
# with open(image_path, "rb") as image_file:
# image_data = image_file.read()
#
# # 使用 describe 方法
# description, token_count = vision_model.describe(image_data)
# # from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
# # description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
# print(f"描述: {description}")
# print(f"使用的令牌数: {token_count}")
#
# except Exception as e:
# print(f"初始化或处理过程中出错: {str(e)}")
# 准备配置vision_model信息
# 初始化 QWenCV
vision_model = QWenCV(
key="sk-8e9e40cd171749858ce2d3722ea75669",
model_name="qwen-vl-max",
lang="Chinese", # 默认使用中文
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
)
try:
# 测试图像描述功能
image_path = "/Users/sbtjfdn/Downloads/记忆科学/files/10.png"
with open(image_path, "rb") as image_file:
image_data = image_file.read()
# 使用 describe 方法
description, token_count = vision_model.describe(image_data)
# from app.core.rag.prompts.generator import vision_llm_figure_describe_prompt
# description, token_count = vision_model.describe_with_prompt(image_data, prompt=vision_llm_figure_describe_prompt())
print(f"描述: {description}")
print(f"使用的令牌数: {token_count}")
except Exception as e:
print(f"初始化或处理过程中出错: {str(e)}")

View File

@@ -0,0 +1,179 @@
import base64
import io
import json
import os
import re
from abc import ABC
import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from app.core.rag.common.token_utils import num_tokens_from_string
class Base(ABC):
def __init__(self, key, model_name, **kwargs):
"""
Abstract base class constructor.
Parameters are not stored; initialization is left to subclasses.
"""
pass
def transcription(self, audio_path, **kwargs):
audio_file = open(audio_path, "rb")
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self, audio):
if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO):
return base64.b64encode(audio.getvalue()).decode("utf-8")
raise TypeError("The input audio file should be in binary format.")
class GPTSeq2txt(Base):
_FACTORY_NAME = "OpenAI"
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
class QWenSeq2txt(Base):
_FACTORY_NAME = "Tongyi-Qianwen"
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
import dashscope
dashscope.api_key = key
self.model_name = model_name
def transcription(self, audio_path):
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
from dashscope import MultiModalConversation
audio_path = f"file://{audio_path}"
messages = [
{
"role": "user",
"content": [{"audio": audio_path}],
}
]
response = None
full_content = ""
try:
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
for response in response:
try:
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
except Exception:
pass
return full_content, num_tokens_from_string(full_content)
except Exception as e:
return "**ERROR**: " + str(e), 0
class AzureSeq2txt(Base):
_FACTORY_NAME = "Azure-OpenAI"
def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
self.model_name = model_name
self.lang = lang
class XinferenceSeq2txt(Base):
_FACTORY_NAME = "Xinference"
def __init__(self, key, model_name="whisper-small", **kwargs):
self.base_url = kwargs.get("base_url", None)
self.model_name = model_name
self.key = key
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
if isinstance(audio, str):
audio_file = open(audio, "rb")
audio_data = audio_file.read()
audio_file_name = audio.split("/")[-1]
else:
audio_data = audio
audio_file_name = "audio.wav"
payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
files = {"file": (audio_file_name, audio_data, "audio/wav")}
try:
response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
response.raise_for_status()
result = response.json()
if "text" in result:
transcription_text = result["text"].strip()
return transcription_text, num_tokens_from_string(transcription_text)
else:
return "**ERROR**: Failed to retrieve transcription.", 0
except requests.exceptions.RequestException as e:
return f"**ERROR**: {str(e)}", 0
class GPUStackSeq2txt(Base):
_FACTORY_NAME = "GPUStack"
def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("url cannot be None")
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.base_url = base_url
self.model_name = model_name
self.key = key
class ZhipuSeq2txt(Base):
_FACTORY_NAME = "ZHIPU-AI"
def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
if not base_url:
base_url = "https://open.bigmodel.cn/api/paas/v4"
self.base_url = base_url
self.api_key = key
self.model_name = model_name
self.gen_conf = kwargs.get("gen_conf", {})
self.stream = kwargs.get("stream", False)
def transcription(self, audio_path):
payload = {
"model": self.model_name,
"temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75",
"stream": self.stream,
}
headers = {"Authorization": f"Bearer {self.api_key}"}
with open(audio_path, "rb") as audio_file:
files = {"file": audio_file}
try:
response = requests.post(
url=f"{self.base_url}/audio/transcriptions",
data=payload,
files=files,
headers=headers,
)
body = response.json()
if response.status_code == 200:
full_content = body["text"]
return full_content, num_tokens_from_string(full_content)
else:
error = body["error"]
return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0
except Exception as e:
return "**ERROR**: " + str(e), 0

View File

View File

@@ -0,0 +1,72 @@
from pydantic import BaseModel, Field
class ChildDocumentChunk(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
vector: list[float] | None = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
class DocumentChunk(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
vector: list[float] | None = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = Field(default_factory=dict)
children: list[ChildDocumentChunk] | None = None
class GeneralStructureChunk(BaseModel):
"""
General Structure Chunk.
"""
general_chunks: list[str]
class ParentChildChunk(BaseModel):
"""
Parent Child Chunk.
"""
parent_content: str
child_contents: list[str]
class ParentChildStructureChunk(BaseModel):
"""
Parent Child Structure Chunk.
"""
parent_child_chunks: list[ParentChildChunk]
parent_mode: str = "paragraph"
class QAChunk(BaseModel):
"""
QA Chunk.
"""
question: str
answer: str
class QAStructureChunk(BaseModel):
"""
QAStructureChunk.
"""
qa_chunks: list[QAChunk]

View File

@@ -0,0 +1,857 @@
import logging
import random
from collections import Counter
from app.core.rag.common.token_utils import num_tokens_from_string
from . import rag_tokenizer
import re
import copy
import roman_numbers as r
from word2number import w2n
from cn2an import cn2an
from PIL import Image
import chardet
all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
'cp037', 'cp273', 'cp424', 'cp437',
'cp500', 'cp720', 'cp737', 'cp775', 'cp850', 'cp852', 'cp855', 'cp856', 'cp857',
'cp858', 'cp860', 'cp861', 'cp862', 'cp863', 'cp864', 'cp865', 'cp866', 'cp869',
'cp874', 'cp875', 'cp932', 'cp949', 'cp950', 'cp1006', 'cp1026', 'cp1125',
'cp1140', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', 'cp1256',
'cp1257', 'cp1258', 'euc_jp', 'euc_jis_2004', 'euc_jisx0213', 'euc_kr',
'gb18030', 'hz', 'iso2022_jp', 'iso2022_jp_1', 'iso2022_jp_2',
'iso2022_jp_2004', 'iso2022_jp_3', 'iso2022_jp_ext', 'iso2022_kr', 'latin_1',
'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', 'iso8859_7',
'iso8859_8', 'iso8859_9', 'iso8859_10', 'iso8859_11', 'iso8859_13',
'iso8859_14', 'iso8859_15', 'iso8859_16', 'johab', 'koi8_r', 'koi8_t', 'koi8_u',
'kz1048', 'mac_cyrillic', 'mac_greek', 'mac_iceland', 'mac_latin2', 'mac_roman',
'mac_turkish', 'ptcp154', 'shift_jis', 'shift_jis_2004', 'shift_jisx0213',
'utf_32', 'utf_32_be', 'utf_32_le', 'utf_16_be', 'utf_16_le', 'utf_7', 'windows-1250', 'windows-1251',
'windows-1252', 'windows-1253', 'windows-1254', 'windows-1255', 'windows-1256',
'windows-1257', 'windows-1258', 'latin-2'
]
def find_codec(blob):
detected = chardet.detect(blob[:1024])
if detected['confidence'] > 0.5:
if detected['encoding'] == "ascii":
return "utf-8"
for c in all_codecs:
try:
blob[:1024].decode(c)
return c
except Exception:
pass
try:
blob.decode(c)
return c
except Exception:
pass
return "utf-8"
QUESTION_PATTERN = [
r"第([零一二三四五六七八九十百0-9]+)问",
r"第([零一二三四五六七八九十百0-9]+)条",
r"[\(]([零一二三四五六七八九十百]+)[\)]",
r"第([0-9]+)问",
r"第([0-9]+)条",
r"([0-9]{1,2})[\. 、]",
r"([零一二三四五六七八九十百]+)[ 、]",
r"[\(]([0-9]{1,2})[\)]",
r"QUESTION (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"QUESTION (I+V?|VI*|XI|IX|X)",
r"QUESTION ([0-9]+)",
]
def has_qbullet(reg, box, last_box, last_index, last_bull, bull_x0_list):
section, last_section = box['text'], last_box['text']
q_reg = r'(\w|\W)*?(?:|\?|\n|$)+'
full_reg = reg + q_reg
has_bull = re.match(full_reg, section)
index_str = None
if has_bull:
if 'x0' not in last_box:
last_box['x0'] = box['x0']
if 'top' not in last_box:
last_box['top'] = box['top']
if last_bull and box['x0'] - last_box['x0'] > 10:
return None, last_index
if not last_bull and box['x0'] >= last_box['x0'] and box['top'] - last_box['top'] < 20:
return None, last_index
avg_bull_x0 = 0
if bull_x0_list:
avg_bull_x0 = sum(bull_x0_list) / len(bull_x0_list)
else:
avg_bull_x0 = box['x0']
if box['x0'] - avg_bull_x0 > 10:
return None, last_index
index_str = has_bull.group(1)
index = index_int(index_str)
if last_section[-1] == ':' or last_section[-1] == '':
return None, last_index
if not last_index or index >= last_index:
bull_x0_list.append(box['x0'])
return has_bull, index
if section[-1] == '?' or section[-1] == '':
bull_x0_list.append(box['x0'])
return has_bull, index
if box['layout_type'] == 'title':
bull_x0_list.append(box['x0'])
return has_bull, index
pure_section = section.lstrip(re.match(reg, section).group()).lower()
ask_reg = r'(what|when|where|how|why|which|who|whose|为什么|为啥|哪)'
if re.match(ask_reg, pure_section):
bull_x0_list.append(box['x0'])
return has_bull, index
return None, last_index
def index_int(index_str):
res = -1
try:
res = int(index_str)
except ValueError:
try:
res = w2n.word_to_num(index_str)
except ValueError:
try:
res = cn2an(index_str)
except ValueError:
try:
res = r.number(index_str)
except ValueError:
return -1
return res
def qbullets_category(sections):
global QUESTION_PATTERN
hits = [0] * len(QUESTION_PATTERN)
for i, pro in enumerate(QUESTION_PATTERN):
for sec in sections:
if re.match(pro, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
continue
res = i
maxium = h
return res, QUESTION_PATTERN[res]
BULLET_PATTERN = [[
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"第[零一二三四五六七八九十百0-9]+条",
r"[\(][零一二三四五六七八九十百]+[\)]",
], [
r"第[0-9]+章",
r"第[0-9]+节",
r"[0-9]{,2}[\. 、]",
r"[0-9]{,2}\.[0-9]{,2}[^a-zA-Z/%~-]",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"[零一二三四五六七八九十百]+[ 、]",
r"[\(][零一二三四五六七八九十百]+[\)]",
r"[\(][0-9]{,2}[\)]",
], [
r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"Chapter (I+V?|VI*|XI|IX|X)",
r"Section [0-9]+",
r"Article [0-9]+"
], [
r"^#[^#]",
r"^##[^#]",
r"^###.*",
r"^####.*",
r"^#####.*",
r"^######.*",
]
]
def random_choices(arr, k):
k = min(len(arr), k)
return random.choices(arr, k=k)
def not_bullet(line):
patt = [
r"0", r"[0-9]+ +[0-9~个只-]", r"[0-9]+\.{2,}"
]
return any([re.match(r, line) for r in patt])
def bullets_category(sections):
global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN)
for i, pro in enumerate(BULLET_PATTERN):
for sec in sections:
sec = sec.strip()
for p in pro:
if re.match(p, sec) and not not_bullet(sec):
hits[i] += 1
break
maxium = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium:
continue
res = i
maxium = h
return res
def is_english(texts):
if not texts:
return False
pattern = re.compile(r"[`a-zA-Z0-9\s.,':;/\"?<>!\(\)\-]")
if isinstance(texts, str):
texts = list(texts)
elif isinstance(texts, list):
texts = [t for t in texts if isinstance(t, str) and t.strip()]
else:
return False
if not texts:
return False
eng = sum(1 for t in texts if pattern.fullmatch(t.strip()))
return (eng / len(texts)) > 0.8
def is_chinese(text):
if not text:
return False
chinese = 0
for ch in text:
if '\u4e00' <= ch <= '\u9fff':
chinese += 1
if chinese / len(text) > 0.2:
return True
return False
def tokenize(d, t, eng):
d["content_with_weight"] = t
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", t)
d["content_ltks"] = rag_tokenizer.tokenize(t)
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ii, ck in enumerate(chunks):
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
try:
d["image"], poss = pdf_parser.crop(ck, need_position=True)
add_positions(d, poss)
ck = pdf_parser.remove_tag(ck)
except NotImplementedError:
pass
else:
add_positions(d, [[ii]*5])
tokenize(d, ck, eng)
res.append(d)
return res
def tokenize_chunks_with_images(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ii, (ck, image) in enumerate(zip(chunks, images)):
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
add_positions(d, [[ii]*5])
tokenize(d, ck, eng)
res.append(d)
return res
def tokenize_table(tbls, doc, eng, batch_size=10):
res = []
# add tables
for (img, rows), poss in tbls:
if not rows:
continue
if isinstance(rows, str):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else " "
for i in range(0, len(rows), batch_size):
d = copy.deepcopy(doc)
r = de.join(rows[i:i + batch_size])
tokenize(d, r, eng)
if img:
d["image"] = img
d["doc_type_kwd"] = "image"
add_positions(d, poss)
res.append(d)
return res
def add_positions(d, poss):
if not poss:
return
page_num_int = []
position_int = []
top_int = []
for pn, left, right, top, bottom in poss:
page_num_int.append(int(pn + 1))
top_int.append(int(top))
position_int.append((int(pn + 1), int(left), int(right), int(top), int(bottom)))
d["page_num_int"] = page_num_int
d["position_int"] = position_int
d["top_int"] = top_int
def remove_contents_table(sections, eng=False):
i = 0
while i < len(sections):
def get(i):
nonlocal sections
return (sections[i] if isinstance(sections[i],
type("")) else sections[i][0]).strip()
if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], flags=re.IGNORECASE)):
i += 1
continue
sections.pop(i)
if i >= len(sections):
break
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
while not prefix:
sections.pop(i)
if i >= len(sections):
break
prefix = get(i)[:3] if not eng else " ".join(get(i).split()[:2])
sections.pop(i)
if i >= len(sections) or not prefix:
break
for j in range(i, min(i + 128, len(sections))):
if not re.match(prefix, get(j)):
continue
for _ in range(i, j):
sections.pop(i)
break
def make_colon_as_title(sections):
if not sections:
return []
if isinstance(sections[0], type("")):
return sections
i = 0
while i < len(sections):
txt, layout = sections[i]
i += 1
txt = txt.split("@")[0].strip()
if not txt:
continue
if txt[-1] not in ":":
continue
txt = txt[::-1]
arr = re.split(r"([。?!!?;]| \.)", txt)
if len(arr) < 2 or len(arr[1]) < 32:
continue
sections.insert(i - 1, (arr[0][::-1], "title"))
i += 1
def title_frequency(bull, sections):
bullets_size = len(BULLET_PATTERN[bull])
levels = [bullets_size + 1 for _ in range(len(sections))]
if not sections or bull < 0:
return bullets_size + 1, levels
for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()) and not not_bullet(txt):
levels[i] = j
break
else:
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size + 1
for level, c in sorted(Counter(levels).items(), key=lambda x: x[1] * -1):
if level <= bullets_size:
most_level = level
break
return most_level, levels
def not_title(txt):
if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt):
return False
if len(txt.split()) > 12 or (txt.find(" ") < 0 and len(txt) >= 32):
return True
return re.search(r"[,;,。;!!]", txt)
def tree_merge(bull, sections, depth):
if not sections or bull < 0:
return sections
if isinstance(sections[0], type("")):
sections = [(s, "") for s in sections]
# filter out position information in pdf sections
sections = [(t, o) for t, o in sections if
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
def get_level(bull, section):
text, layout = section
text = re.sub(r"\u3000", " ", text).strip()
for i, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, text.strip()):
return i+1, text
else:
if re.search(r"(title|head)", layout) and not not_title(text):
return len(BULLET_PATTERN[bull])+1, text
else:
return len(BULLET_PATTERN[bull])+2, text
level_set = set()
lines = []
for section in sections:
level, text = get_level(bull, section)
if not text.strip("\n"):
continue
lines.append((level, text))
level_set.add(level)
sorted_levels = sorted(list(level_set))
if depth <= len(sorted_levels):
target_level = sorted_levels[depth - 1]
else:
target_level = sorted_levels[-1]
if target_level == len(BULLET_PATTERN[bull]) + 2:
target_level = sorted_levels[-2] if len(sorted_levels) > 1 else sorted_levels[0]
root = Node(level=0, depth=target_level, texts=[])
root.build_tree(lines)
return [("\n").join(element) for element in root.get_tree() if element]
def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0:
return []
if isinstance(sections[0], type("")):
sections = [(s, "") for s in sections]
sections = [(t, o) for t, o in sections if
t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
bullets_size = len(BULLET_PATTERN[bull])
levels = [[] for _ in range(bullets_size + 2)]
for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()):
levels[j].append(i)
break
else:
if re.search(r"(title|head)", layout) and not not_title(txt):
levels[bullets_size].append(i)
else:
levels[bullets_size + 1].append(i)
sections = [t for t, _ in sections]
# for s in sections: print("--", s)
def binary_search(arr, target):
if not arr:
return -1
if target > arr[-1]:
return len(arr) - 1
if target < arr[0]:
return -1
s, e = 0, len(arr)
while e - s > 1:
i = (e + s) // 2
if target > arr[i]:
s = i
continue
elif target < arr[i]:
e = i
continue
else:
assert False
return s
cks = []
readed = [False] * len(sections)
levels = levels[::-1]
for i, arr in enumerate(levels[:depth]):
for j in arr:
if readed[j]:
continue
readed[j] = True
cks.append([j])
if i + 1 == len(levels) - 1:
continue
for ii in range(i + 1, len(levels)):
jj = binary_search(levels[ii], j)
if jj < 0:
continue
if levels[ii][jj] > cks[-1][-1]:
cks[-1].pop(-1)
cks[-1].append(levels[ii][jj])
for ii in cks[-1]:
readed[ii] = True
if not cks:
return cks
for i in range(len(cks)):
cks[i] = [sections[j] for j in cks[i][::-1]]
logging.debug("\n* ".join(cks[i]))
res = [[]]
num = [0]
for ck in cks:
if len(ck) == 1:
n = num_tokens_from_string(re.sub(r"@@[0-9]+.*", "", ck[0]))
if n + num[-1] < 218:
res[-1].append(ck[0])
num[-1] += n
continue
res.append(ck)
num.append(n)
continue
res.append(ck)
num.append(218)
return res
def naive_merge(sections: str | list, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
if not sections:
return []
if isinstance(sections, str):
sections = [sections]
if isinstance(sections[0], str):
sections = [(s, "") for s in sections]
cks = [""]
tk_nums = [0]
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
for sec, pos in sections:
if num_tokens_from_string(sec) < chunk_token_num:
add_chunk("\n"+sec, pos)
continue
split_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, pos)
return cks
def naive_merge_with_images(texts, images, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from app.core.rag.deepdoc.parser.pdf_parser import RAGPdfParser
if not texts or len(texts) != len(images):
return [], []
cks = [""]
result_images = [None]
tk_nums = [0]
def add_chunk(t, image, pos=""):
nonlocal cks, result_images, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
result_images.append(image)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
if result_images[-1] is None:
result_images[-1] = image
else:
result_images[-1] = concat_img(result_images[-1], image)
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
for text, image in zip(texts, images):
# if text is tuple, unpack it
if isinstance(text, tuple):
text_str = text[0]
text_pos = text[1] if len(text) > 1 else ""
split_sec = re.split(r"(%s)" % dels, text_str)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image, text_pos)
else:
split_sec = re.split(r"(%s)" % dels, text)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image)
return cks, result_images
def docx_question_level(p, bull=-1):
txt = re.sub(r"\u3000", " ", p.text).strip()
if p.style.name.startswith('Heading'):
return int(p.style.name.split(' ')[-1]), txt
else:
if bull < 0:
return 0, txt
for j, title in enumerate(BULLET_PATTERN[bull]):
if re.match(title, txt):
return j + 1, txt
return len(BULLET_PATTERN[bull])+1, txt
def concat_img(img1, img2):
if img1 and not img2:
return img1
if not img1 and img2:
return img2
if not img1 and not img2:
return None
if img1 is img2:
return img1
if isinstance(img1, Image.Image) and isinstance(img2, Image.Image):
pixel_data1 = img1.tobytes()
pixel_data2 = img2.tobytes()
if pixel_data1 == pixel_data2:
return img1
width1, height1 = img1.size
width2, height2 = img2.size
new_width = max(width1, width2)
new_height = height1 + height2
new_image = Image.new('RGB', (new_width, new_height))
new_image.paste(img1, (0, 0))
new_image.paste(img2, (0, height1))
return new_image
def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
if not sections:
return [], []
cks = [""]
images = [None]
tk_nums = [0]
def add_chunk(t, image, pos=""):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if tnum < 8:
pos = ""
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
if t.find(pos) < 0:
t += pos
cks.append(t)
images.append(image)
tk_nums.append(tnum)
else:
if cks[-1].find(pos) < 0:
t += pos
cks[-1] += t
images[-1] = concat_img(images[-1], image)
tk_nums[-1] += tnum
dels = get_delimiters(delimiter)
line = ""
for sec, image in sections:
if not image:
line += sec + "\n"
continue
split_sec = re.split(r"(%s)" % dels, line + sec)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image,"")
line = ""
if line:
split_sec = re.split(r"(%s)" % dels, line)
for sub_sec in split_sec:
if re.match(f"^{dels}$", sub_sec):
continue
add_chunk("\n"+sub_sec, image,"")
return cks, images
def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
return re.findall(pattern, text, flags=re.DOTALL)
def get_delimiters(delimiters: str):
dels = []
s = 0
for m in re.finditer(r"`([^`]+)`", delimiters, re.I):
f, t = m.span()
dels.append(m.group(1))
dels.extend(list(delimiters[s: f]))
s = t
if s < len(delimiters):
dels.extend(list(delimiters[s:]))
dels.sort(key=lambda x: -len(x))
dels = [re.escape(d) for d in dels if d]
dels = [d for d in dels if d]
dels_pattern = "|".join(dels)
return dels_pattern
class Node:
def __init__(self, level, depth=-1, texts=None):
self.level = level
self.depth = depth
self.texts = texts or []
self.children = []
def add_child(self, child_node):
self.children.append(child_node)
def get_children(self):
return self.children
def get_level(self):
return self.level
def get_texts(self):
return self.texts
def set_texts(self, texts):
self.texts = texts
def add_text(self, text):
self.texts.append(text)
def clear_text(self):
self.texts = []
def __repr__(self):
return f"Node(level={self.level}, texts={self.texts}, children={len(self.children)})"
def build_tree(self, lines):
stack = [self]
for level, text in lines:
if self.depth != -1 and level > self.depth:
# Beyond target depth: merge content into the current leaf instead of creating deeper nodes
stack[-1].add_text(text)
continue
# Move up until we find the proper parent whose level is strictly smaller than current
while len(stack) > 1 and level <= stack[-1].get_level():
stack.pop()
node = Node(level=level, texts=[text])
# Attach as child of current parent and descend
stack[-1].add_child(node)
stack.append(node)
return self
def get_tree(self):
tree_list = []
self._dfs(self, tree_list, [])
return tree_list
def _dfs(self, node, tree_list, titles):
level = node.get_level()
texts = node.get_texts()
child = node.get_children()
if level == 0 and texts:
tree_list.append("\n".join(titles+texts))
# Titles within configured depth are accumulated into the current path
if 1 <= level <= self.depth:
path_titles = titles + texts
else:
path_titles = titles
# Body outside the depth limit becomes its own chunk under the current title path
if level > self.depth and texts:
tree_list.append("\n".join(path_titles + texts))
# A leaf title within depth emits its title path as a chunk (header-only section)
elif not child and (1 <= level <= self.depth):
tree_list.append("\n".join(path_titles))
# Recurse into children with the updated title path
for c in child:
self._dfs(c, tree_list, path_titles)

View File

@@ -0,0 +1,261 @@
import logging
import json
import re
from collections import defaultdict
from app.core.rag.utils.doc_store_conn import MatchTextExpr
from . import rag_tokenizer, term_weight, synonym
class FulltextQueryer:
def __init__(self):
self.tw = term_weight.Dealer()
self.syn = synonym.Dealer()
self.query_fields = [
"title_tks^10",
"title_sm_tks^5",
"important_kwd^30",
"important_tks^20",
"question_tks^20",
"content_ltks^2",
"content_sm_ltks",
]
@staticmethod
def subSpecialChar(line):
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
@staticmethod
def isChinese(line):
arr = re.split(r"[ \t]+", line)
if len(arr) <= 3:
return True
e = 0
for t in arr:
if not re.match(r"[a-zA-Z]+$", t):
e += 1
return e * 1.0 / len(arr) >= 0.7
@staticmethod
def rmWWW(txt):
patts = [
(
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
"",
),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
" ")
]
otxt = txt
for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
if not txt:
txt = otxt
return txt
@staticmethod
def add_space_between_eng_zh(txt):
# (ENG/ENG+NUM) + ZH
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ENG + ZH
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
# ZH + (ENG/ENG+NUM)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
return txt
def question(self, txt, tbl="qa", min_match: float = 0.6):
txt = FulltextQueryer.add_space_between_eng_zh(txt)
txt = re.sub(
r"[ :|\r\n\t,,。??/`!&^%%()\[\]{}<>]+",
" ",
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
).strip()
otxt = txt
txt = FulltextQueryer.rmWWW(txt)
if not self.isChinese(txt):
txt = FulltextQueryer.rmWWW(txt)
tks = rag_tokenizer.tokenize(txt).split()
keywords = [t for t in tks if t]
tks_w = self.tw.weights(tks, preprocess=False)
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
syns = []
for tk, w in tks_w[:256]:
syn = self.syn.lookup(tk)
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
keywords.extend(syn)
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
syns.append(" ".join(syn))
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
tk and not re.match(r"[.^+\(\)-]", tk)]
for i in range(1, len(tks_w)):
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
if not left or not right:
continue
q.append(
'"%s %s"^%.4f'
% (
tks_w[i - 1][0],
tks_w[i][0],
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
)
)
if not q:
q.append(txt)
query = " ".join(q)
return MatchTextExpr(
self.query_fields, query, 100
), keywords
def need_fine_grained_tokenize(tk):
if len(tk) < 3:
return False
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
return False
return True
txt = FulltextQueryer.rmWWW(txt)
qs, keywords = [], []
for tt in self.tw.split(txt)[:256]: # .split():
if not tt:
continue
keywords.append(tt)
twts = self.tw.weights([tt])
syns = self.syn.lookup(tt)
if syns and len(keywords) < 32:
keywords.extend(syns)
logging.debug(json.dumps(twts, ensure_ascii=False))
tms = []
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
sm = (
rag_tokenizer.fine_grained_tokenize(tk).split()
if need_fine_grained_tokenize(tk)
else []
)
sm = [
re.sub(
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
"",
m,
)
for m in sm
]
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
sm = [m for m in sm if len(m) > 1]
if len(keywords) < 32:
keywords.append(re.sub(r"[ \\\"']+", "", tk))
keywords.extend(sm)
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
if len(keywords) < 32:
keywords.extend([s for s in tk_syns if s])
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
if len(keywords) >= 32:
break
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if sm:
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])
if len(twts) > 1:
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
syns = " OR ".join(
[
'"%s"'
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
for s in syns
]
)
if syns and tms:
tms = f"({tms})^5 OR ({syns})^0.7"
qs.append(tms)
if qs:
query = " OR ".join([f"({t})" for t in qs if t])
if not query:
query = otxt
return MatchTextExpr(
self.query_fields, query, 100, {"minimum_should_match": min_match}
), keywords
return None, keywords
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
sims = CosineSimilarity([avec], bvecs)
tksim = self.token_similarity(atks, btkss)
if np.sum(sims[0]) == 0:
return np.array(tksim), tksim, sims[0]
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
def token_similarity(self, atks, btkss):
def toDict(tks):
if isinstance(tks, str):
tks = tks.split()
d = defaultdict(int)
wts = self.tw.weights(tks, preprocess=False)
for i, (t, c) in enumerate(wts):
d[t] += c
return d
atks = toDict(atks)
btkss = [toDict(tks) for tks in btkss]
return [self.similarity(atks, btks) for btks in btkss]
def similarity(self, qtwt, dtwt):
if isinstance(dtwt, type("")):
dtwt = {t: w for t, w in self.tw.weights(self.tw.split(dtwt), preprocess=False)}
if isinstance(qtwt, type("")):
qtwt = {t: w for t, w in self.tw.weights(self.tw.split(qtwt), preprocess=False)}
s = 1e-9
for k, v in qtwt.items():
if k in dtwt:
s += v #* dtwt[k]
q = 1e-9
for k, v in qtwt.items():
q += v #* v
return s/q #math.sqrt(3. * (s / q / math.log10( len(dtwt.keys()) + 512 )))
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
if isinstance(content_tks, str):
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
tks_w = self.tw.weights(content_tks, preprocess=False)
keywords = [f'"{k.strip()}"' for k in keywords]
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
tk_syns = self.syn.lookup(tk)
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
tk = FulltextQueryer.subSpecialChar(tk)
if tk.find(" ") > 0:
tk = '"%s"' % tk
if tk_syns:
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
if tk:
keywords.append(f"{tk}^{w}")
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
{"minimum_should_match": min(3, len(keywords) // 10)})

View File

@@ -0,0 +1,499 @@
import logging
import copy
import datrie
import math
import os
import re
import string
import sys
from hanziconv import HanziConv
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from app.core.rag.common.file_utils import get_project_base_directory
class RagTokenizer:
def key_(self, line):
return str(line.lower().encode("utf-8"))[2:-1]
def rkey_(self, line):
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
def loadDict_(self, fnm):
logging.info(f"[HUQIE]:Build trie from {fnm}")
try:
of = open(fnm, "r", encoding='utf-8')
while True:
line = of.readline()
if not line:
break
line = re.sub(r"[\r\n]+", "", line)
line = re.split(r"[ \t]", line)
k = self.key_(line[0])
F = int(math.log(float(line[1]) / self.DENOMINATOR) + .5)
if k not in self.trie_ or self.trie_[k][0] < F:
self.trie_[self.key_(line[0])] = (F, line[2])
self.trie_[self.rkey_(line[0])] = 1
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
logging.info(f"[HUQIE]:Build trie cache to {trie_file_name}")
self.trie_.save(trie_file_name)
of.close()
except Exception:
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
def __init__(self, debug=False):
self.DEBUG = debug
self.DENOMINATOR = 1000000
self.stemmer = PorterStemmer()
self.lemmatizer = WordNetLemmatizer()
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
trie_file_name = os.path.join(get_project_base_directory(), "res", "huqie") + ".txt.trie"
# check if trie file existence
if os.path.exists(trie_file_name):
try:
# load trie from file
self.trie_ = datrie.Trie.load(trie_file_name)
return
except Exception:
# fail to load trie from file, build default trie
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
else:
# file not exist, build default trie
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
self.trie_ = datrie.Trie(string.printable)
# load data from dict file and save to trie file
self.loadDict_(os.path.join(get_project_base_directory(), "app/core/rag/res", "huqie") + ".txt")
def loadUserDict(self, fnm):
try:
self.trie_ = datrie.Trie.load(fnm + ".trie")
return
except Exception:
self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm)
def addUserDict(self, fnm):
self.loadDict_(fnm)
def _strQ2B(self, ustring):
"""Convert full-width characters to half-width characters"""
rstring = ""
for uchar in ustring:
inside_code = ord(uchar)
if inside_code == 0x3000:
inside_code = 0x0020
else:
inside_code -= 0xfee0
if inside_code < 0x0020 or inside_code > 0x7e: # After the conversion, if it's not a half-width character, return the original character.
rstring += uchar
else:
rstring += chr(inside_code)
return rstring
def _tradi2simp(self, line):
return HanziConv.toSimplified(line)
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
if _memo is None:
_memo = {}
MAX_DEPTH = 10
if _depth > MAX_DEPTH:
if s < len(chars):
copy_pretks = copy.deepcopy(preTks)
remaining = "".join(chars[s:])
copy_pretks.append((remaining, (-12, '')))
tkslist.append(copy_pretks)
return s
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
if state_key in _memo:
return _memo[state_key]
res = s
if s >= len(chars):
tkslist.append(preTks)
_memo[state_key] = s
return s
if s < len(chars) - 4:
is_repetitive = True
char_to_check = chars[s]
for i in range(1, 5):
if s + i >= len(chars) or chars[s + i] != char_to_check:
is_repetitive = False
break
if is_repetitive:
end = s
while end < len(chars) and chars[end] == char_to_check:
end += 1
mid = s + min(10, end - s)
t = "".join(chars[s:mid])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
res = max(res, next_res)
_memo[state_key] = res
return res
S = s + 1
if s + 2 <= len(chars):
t1 = "".join(chars[s:s + 1])
t2 = "".join(chars[s:s + 2])
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
S = s + 2
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
t1 = preTks[-1][0] + "".join(chars[s:s + 1])
if self.trie_.has_keys_with_prefix(self.key_(t1)):
S = s + 2
for e in range(S, len(chars) + 1):
t = "".join(chars[s:e])
k = self.key_(t)
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
break
if k in self.trie_:
pretks = copy.deepcopy(preTks)
pretks.append((t, self.trie_[k]))
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
if res > s:
_memo[state_key] = res
return res
t = "".join(chars[s:s + 1])
k = self.key_(t)
copy_pretks = copy.deepcopy(preTks)
if k in self.trie_:
copy_pretks.append((t, self.trie_[k]))
else:
copy_pretks.append((t, (-12, '')))
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
_memo[state_key] = result
return result
def freq(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return 0
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
def tag(self, tk):
k = self.key_(tk)
if k not in self.trie_:
return ""
return self.trie_[k][1]
def score_(self, tfts):
B = 30
F, L, tks = 0, 0, []
for tk, (freq, tag) in tfts:
F += freq
L += 0 if len(tk) < 2 else 1
tks.append(tk)
#F /= len(tks)
L /= len(tks)
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F
def sortTks_(self, tkslist):
res = []
for tfts in tkslist:
tks, s = self.score_(tfts)
res.append((tks, s))
return sorted(res, key=lambda x: x[1], reverse=True)
def merge_(self, tks):
# if split chars is part of token
res = []
tks = re.sub(r"[ ]+", " ", tks).split()
s = 0
while True:
if s >= len(tks):
break
E = s + 1
for e in range(s + 2, min(len(tks) + 2, s + 6)):
tk = "".join(tks[s:e])
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
E = e
res.append("".join(tks[s:E]))
s = E
return " ".join(res)
def maxForward_(self, line):
res = []
s = 0
while s < len(line):
e = s + 1
t = line[s:e]
while e < len(line) and self.trie_.has_keys_with_prefix(
self.key_(t)):
e += 1
t = line[s:e]
while e - 1 > s and self.key_(t) not in self.trie_:
e -= 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s = e
return self.score_(res)
def maxBackward_(self, line):
res = []
s = len(line) - 1
while s >= 0:
e = s + 1
t = line[s:e]
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
s -= 1
t = line[s:e]
while s + 1 < e and self.key_(t) not in self.trie_:
s += 1
t = line[s:e]
if self.key_(t) in self.trie_:
res.append((t, self.trie_[self.key_(t)]))
else:
res.append((t, (0, '')))
s -= 1
return self.score_(res[::-1])
def english_normalize_(self, tks):
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
def _split_by_lang(self, line):
txt_lang_pairs = []
arr = re.split(self.SPLIT_CHAR, line)
for a in arr:
if not a:
continue
s = 0
e = s + 1
zh = is_chinese(a[s])
while e < len(a):
_zh = is_chinese(a[e])
if _zh == zh:
e += 1
continue
txt_lang_pairs.append((a[s: e], zh))
s = e
e = s + 1
zh = _zh
if s >= len(a):
continue
txt_lang_pairs.append((a[s: e], zh))
return txt_lang_pairs
def tokenize(self, line):
line = re.sub(r"\W+", " ", line)
line = self._strQ2B(line).lower()
line = self._tradi2simp(line)
arr = self._split_by_lang(line)
res = []
for L,lang in arr:
if not lang:
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
continue
if len(L) < 2 or re.match(
r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
res.append(L)
continue
# use maxforward for the first time
tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L)
if self.DEBUG:
logging.debug("[FW] {} {}".format(tks, s))
logging.debug("[BW] {} {}".format(tks1, s1))
i, j, _i, _j = 0, 0, 0, 0
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0:
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1
while i < len(tks1) and j < len(tks):
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
if tk1 != tk:
if len(tk1) > len(tk):
j += 1
else:
i += 1
continue
if tks1[i] != tks[j]:
i += 1
j += 1
continue
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1
if _i < len(tks1):
assert _j < len(tks)
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
res = " ".join(res)
logging.debug("[TKS] {}".format(self.merge_(res)))
return self.merge_(res)
def fine_grained_tokenize(self, tks):
tks = tks.split()
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
if zh_num < len(tks) * 0.2:
res = []
for tk in tks:
res.extend(tk.split("/"))
return " ".join(res)
res = []
for tk in tks:
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
res.append(tk)
continue
tkslist = []
if len(tk) > 10:
tkslist.append(tk)
else:
self.dfs_(tk, 0, [], tkslist)
if len(tkslist) < 2:
res.append(tk)
continue
stk = self.sortTks_(tkslist)[1][0]
if len(stk) == len(tk):
stk = tk
else:
if re.match(r"[a-z\.-]+$", tk):
for t in stk:
if len(t) < 3:
stk = tk
break
else:
stk = " ".join(stk)
else:
stk = " ".join(stk)
res.append(stk)
return " ".join(self.english_normalize_(res))
def is_chinese(s):
if s >= u'\u4e00' and s <= u'\u9fa5':
return True
else:
return False
def is_number(s):
if s >= u'\u0030' and s <= u'\u0039':
return True
else:
return False
def is_alphabet(s):
if (s >= u'\u0041' and s <= u'\u005a') or (
s >= u'\u0061' and s <= u'\u007a'):
return True
else:
return False
def naiveQie(txt):
tks = []
for t in txt.split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
) and re.match(r".*[a-zA-Z]$", t):
tks.append(" ")
tks.append(t)
return tks
tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
tag = tokenizer.tag
freq = tokenizer.freq
loadUserDict = tokenizer.loadUserDict
addUserDict = tokenizer.addUserDict
tradi2simp = tokenizer._tradi2simp
strQ2B = tokenizer._strQ2B
if __name__ == '__main__':
tknzr = RagTokenizer(debug=True)
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("虽然我不怎么玩")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
logging.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
logging.info(tknzr.fine_grained_tokenize(tks))
if len(sys.argv) < 2:
sys.exit()
tknzr.DEBUG = False
tknzr.loadUserDict(sys.argv[1])
of = open(sys.argv[2], "r")
while True:
line = of.readline()
if not line:
break
logging.info(tknzr.tokenize(line))
of.close()

View File

@@ -0,0 +1,192 @@
import uuid
from typing import Dict, List, Any
from sqlalchemy.orm import Session
from langchain_core.documents import Document
from app.db import get_db
from app.core.models.base import RedBearModelConfig
from app.core.models import RedBearLLM, RedBearRerank
from app.models.models_model import ModelApiKey
from app.models import knowledge_model
from app.core.rag.models.chunk import DocumentChunk
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.services.model_service import ModelConfigService
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
def knowledge_retrieval(
query: str,
config: Dict[str, Any],
user_ids: List[str] = None,
) -> list[DocumentChunk]:
"""
Knowledge retrieval with multiple knowledge bases and reranking
Args:
query: Search query string
config: Configuration dictionary containing:
- knowledge_bases: List of knowledge base configs with:
- kb_id: Knowledge base ID
- similarity_threshold: float
- vector_similarity_weight: float
- top_k: int
- retrieve_type: "participle" or "semantic" or "hybrid"
- merge_strategy: "weight" or other strategies
- reranker_id: UUID of the reranker to use
- reranker_top_k: int
Returns:
Rearranged document block list (in descending order of relevance)
"""
db = next(get_db()) # Manually call the generator
try:
# parse configuration
knowledge_bases = config.get("knowledge_bases", [])
merge_strategy = config.get("merge_strategy", "weight")
reranker_id = config.get("reranker_id")
reranker_top_k = config.get("reranker_top_k", 1024)
file_names_filter=[]
if user_ids:
file_names_filter.extend([f"{user_id}.txt" for user_id in user_ids])
if not knowledge_bases:
return []
all_results = []
# Search each knowledge base
for kb_config in knowledge_bases:
kb_id = kb_config["kb_id"]
try:
# Check whether the knowledge base exists and is available
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
# Process shared knowledge base
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
knowledgeshare_id=db_knowledge.id)
if knowledgeshare:
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
knowledge_id=knowledgeshare.source_kb_id)
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
continue
else:
continue
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# Retrieve according to the configured retrieval type
match kb_config["retrieve_type"]:
case "participle":
rs = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
case "semantic":
rs = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
case _: # hybrid
rs1 = vector_service.search_by_vector(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["vector_similarity_weight"],
file_names_filter=file_names_filter
)
rs2 = vector_service.search_by_full_text(
query=query,
top_k=kb_config["top_k"],
score_threshold=kb_config["similarity_threshold"],
file_names_filter=file_names_filter
)
# Deduplication of merge results
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = unique_rs
all_results.extend(rs)
except Exception as e:
# Failure of retrieval in a single knowledge base does not affect other knowledge bases
print(f"retrieval knowledge({kb_id}) failed: {str(e)}")
continue
# Use the specified reranker for re-ranking
if reranker_id:
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
return all_results
except Exception as e:
print(f"retrieval knowledge failed: {str(e)}")
finally:
db.close()
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
Reorder the list of document blocks and return the top_k results most relevant to the query
Args:
reranker_id: reranker model id
query: query string
docs: List of document blocks to be rearranged
top_k: Number of top-level documents returned
Returns:
Rearranged document block list (in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
"""
# 参数校验
if not reranker_id:
raise ValueError("reranker_id be empty")
if not docs:
raise ValueError("retrieval chunks be empty")
if top_k <= 0:
raise ValueError("top_k must be a positive integer")
try:
# initialize reranker
config = ModelConfigService.get_model_by_id(db=db, model_id=reranker_id)
apiConfig: ModelApiKey = config.api_keys[0]
reranker = RedBearRerank(RedBearModelConfig(
model_name=apiConfig.model_name,
provider=apiConfig.provider,
api_key=apiConfig.api_key,
base_url=apiConfig.api_base
))
# Convert to LangChain Document object
documents = [
Document(
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
metadata=doc.metadata or {} # Deal with possible None metadata
)
for doc in docs
]
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
reranked_docs = list(reranker.compress_documents(documents, query))
print(reranked_docs)
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
)
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e

View File

@@ -0,0 +1,126 @@
m = set(["","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","羿","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","宿","","怀",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","寿","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"广","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","",
"","","","西","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","","鹿","",
"万俟","司马","上官","欧阳",
"夏侯","诸葛","闻人","东方",
"赫连","皇甫","尉迟","公羊",
"澹台","公冶","宗政","濮阳",
"淳于","单于","太叔","申屠",
"公孙","仲孙","轩辕","令狐",
"钟离","宇文","长孙","慕容",
"鲜于","闾丘","司徒","司空",
"亓官","司寇","仉督","子车",
"颛孙","端木","巫马","公西",
"漆雕","乐正","壤驷","公良",
"拓跋","夹谷","宰父","榖梁",
"","","","","","","","",
"段干","百里","东郭","南门",
"呼延","","","羊舌","","",
"","","","","","","","",
"梁丘","左丘","东门","西门",
"","","","","","","南宫",
"","","","","","","","",
"第五","",""])
def isit(n):return n.strip() in m

View File

@@ -0,0 +1,85 @@
import logging
import json
import os
import time
import re
from nltk.corpus import wordnet
from app.core.rag.common.file_utils import get_project_base_directory
class Dealer:
def __init__(self, redis=None):
self.lookup_num = 100000000
self.load_tm = time.time() - 1000000
self.dictionary = None
path = os.path.join(get_project_base_directory(), "app/core/rag/res", "synonym.json")
try:
self.dictionary = json.load(open(path, 'r'))
self.dictionary = { (k.lower() if isinstance(k, str) else k): v for k, v in self.dictionary.items() }
except Exception:
logging.warning("Missing synonym.json")
self.dictionary = {}
if not redis:
logging.warning(
"Realtime synonym is disabled, since no redis connection.")
if not len(self.dictionary.keys()):
logging.warning("Fail to load synonym")
self.redis = redis
self.load()
def load(self):
if not self.redis:
return
if self.lookup_num < 100:
return
tm = time.time()
if tm - self.load_tm < 3600:
return
self.load_tm = time.time()
self.lookup_num = 0
d = self.redis.get("kevin_synonyms")
if not d:
return
try:
d = json.loads(d)
self.dictionary = d
except Exception as e:
logging.error("Fail to load synonym!" + str(e))
def lookup(self, tk, topn=8):
if not tk or not isinstance(tk, str):
return []
# 1) Check the custom dictionary first (both keys and tk are already lowercase)
self.lookup_num += 1
self.load()
key = re.sub(r"[ \t]+", " ", tk.strip())
res = self.dictionary.get(key, [])
if isinstance(res, str):
res = [res]
if res: # Found in dictionary → return directly
return res[:topn]
# 2) If not found and tk is purely alphabetical → fallback to WordNet
if re.fullmatch(r"[a-z]+", tk):
wn_set = {
re.sub("_", " ", syn.name().split(".")[0])
for syn in wordnet.synsets(tk)
}
wn_set.discard(tk) # Remove the original token itself
wn_res = [t for t in wn_set if t]
return wn_res[:topn]
# 3) Nothing found in either source
return []
if __name__ == '__main__':
dl = Dealer()
print(dl.dictionary)

View File

@@ -0,0 +1,228 @@
import logging
import math
import json
import re
import os
import numpy as np
from . import rag_tokenizer
from app.core.rag.common.file_utils import get_project_base_directory
class Dealer:
def __init__(self):
self.stop_words = set(["请问",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"#",
"什么",
"怎么",
"哪个",
"哪些",
"",
"相关"])
def load_dict(fnm):
res = {}
f = open(fnm, "r")
while True:
line = f.readline()
if not line:
break
arr = line.replace("\n", "").split("\t")
if len(arr) < 2:
res[arr[0]] = 0
else:
res[arr[0]] = int(arr[1])
c = 0
for _, v in res.items():
c += v
if c == 0:
return set(res.keys())
return res
fnm = os.path.join(get_project_base_directory(), "app/core/rag/res")
self.ne, self.df = {}, {}
try:
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
except Exception:
logging.warning("Load ner.json FAIL!")
try:
self.df = load_dict(os.path.join(fnm, "term.freq"))
except Exception:
logging.warning("Load term.freq FAIL!")
def pretoken(self, txt, num=False, stpwd=True):
patt = [
r"[~—\t @#%!<>,\.\?\":;'\{\}\[\]_=\(\)\|,。?》•●○↓《;‘’:“”【¥ 】…¥!、·()×`&\\/「」\\]"
]
rewt = [
]
for p, r in rewt:
txt = re.sub(p, r, txt)
res = []
for t in rag_tokenizer.tokenize(txt).split():
tk = t
if (stpwd and tk in self.stop_words) or (
re.match(r"[0-9]$", tk) and not num):
continue
for p in patt:
if re.match(p, t):
tk = "#"
break
#tk = re.sub(r"([\+\\-])", r"\\\1", tk)
if tk != "#" and tk:
res.append(tk)
return res
def tokenMerge(self, tks):
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
res, i = [], 0
while i < len(tks):
j = i
if i == 0 and oneTerm(tks[i]) and len(
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
res.append(" ".join(tks[0:2]))
i = 2
continue
while j < len(
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
j += 1
if j - i > 1:
if j - i < 5:
res.append(" ".join(tks[i:j]))
i = j
else:
res.append(" ".join(tks[i:i + 2]))
i = i + 2
else:
if len(tks[i]) > 0:
res.append(tks[i])
i += 1
return [t for t in res if t]
def ner(self, t):
if not self.ne:
return ""
res = self.ne.get(t, "")
if res:
return res
def split(self, txt):
tks = []
for t in re.sub(r"[ \t]+", " ", txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r".*[a-zA-Z]$", t) and tks and \
self.ne.get(t, "") != "func" and self.ne.get(tks[-1], "") != "func":
tks[-1] = tks[-1] + " " + t
else:
tks.append(t)
return tks
def weights(self, tks, preprocess=True):
num_pattern = re.compile(r"[0-9,.]{2,}$")
short_letter_pattern = re.compile(r"[a-z]{1,2}$")
num_space_pattern = re.compile(r"[0-9. -]{2,}$")
letter_pattern = re.compile(r"[a-z. -]+$")
def ner(t):
if num_pattern.match(t):
return 2
if short_letter_pattern.match(t):
return 0.01
if not self.ne or t not in self.ne:
return 1
m = {"toxic": 2, "func": 1, "corp": 3, "loca": 3, "sch": 3, "stock": 3,
"firstnm": 1}
return m[self.ne[t]]
def postag(t):
t = rag_tokenizer.tag(t)
if t in set(["r", "c", "d"]):
return 0.3
if t in set(["ns", "nt"]):
return 3
if t in set(["n"]):
return 2
if re.match(r"[0-9-]+", t):
return 2
return 1
def freq(t):
if num_space_pattern.match(t):
return 3
s = rag_tokenizer.freq(t)
if not s and letter_pattern.match(t):
return 300
if not s:
s = 0
if not s and len(t) >= 4:
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
if len(s) > 1:
s = np.min([freq(tt) for tt in s]) / 6.
else:
s = 0
return max(s, 10)
def df(t):
if num_space_pattern.match(t):
return 5
if t in self.df:
return self.df[t] + 3
elif letter_pattern.match(t):
return 300
elif len(t) >= 4:
s = [tt for tt in rag_tokenizer.fine_grained_tokenize(t).split() if len(tt) > 1]
if len(s) > 1:
return max(3, np.min([df(tt) for tt in s]) / 6.)
return 3
def idf(s, N): return math.log10(10 + ((N - s + 0.5) / (s + 0.5)))
tw = []
if not preprocess:
idf1 = np.array([idf(freq(t), 10000000) for t in tks])
idf2 = np.array([idf(df(t), 1000000000) for t in tks])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tks])
wts = [s for s in wts]
tw = list(zip(tks, wts))
else:
for tk in tks:
tt = self.tokenMerge(self.pretoken(tk, True))
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
wts = (0.3 * idf1 + 0.7 * idf2) * \
np.array([ner(t) * postag(t) for t in tt])
wts = [s for s in wts]
tw.extend(zip(tt, wts))
S = np.sum([s for _, s in tw])
return [(t, s / S) for t, s in tw]

View File

@@ -0,0 +1,6 @@
from . import generator
__all__ = [name for name in dir(generator)
if not name.startswith('_')]
globals().update({name: getattr(generator, name) for name in __all__})

View File

@@ -0,0 +1,48 @@
You are an intelligent task analyzer that adapts analysis depth to task complexity.
**Analysis Framework**
**Step 1: Task Transmission Assessment**
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
**Evaluate if task transmission information is needed:**
- **Is this an initial step?** If yes, skip this section
- **Are there upstream agents/steps?** If no, provide minimal transmission
- **Is there critical state/context to preserve?** If yes, include full transmission
### If Task Transmission is Needed:
- **Current State Summary**: [1-2 sentences on where we are]
- **Key Data/Results**: [Critical findings that must carry forward]
- **Context Dependencies**: [Essential context for next agent/step]
- **Unresolved Items**: [Issues requiring continuation]
- **Status for User**: [Clear status update in user terms]
- **Technical State**: [System state for technical handoffs]
**Step 2: Complexity Classification**
Classify as LOW / MEDIUM / HIGH:
- **LOW**: Single-step tasks, direct queries, small talk
- **MEDIUM**: Multi-step tasks within one domain
- **HIGH**: Multi-domain coordination or complex reasoning
**Step 3: Adaptive Analysis**
Scale depth to match complexity. Always stop once success criteria are met.
**For LOW (max 50 words for analysis only):**
- Detect small talk; if true, output exactly: `Small talk — no further analysis needed`
- One-sentence objective
- Direct execution approach (12 steps)
**For MEDIUM (80150 words for analysis only):**
- Objective; Intent & Scope
- 35 step minimal Plan (may mark parallel steps)
- **Uncertainty & Probes** (at least one probe with a clear stop condition)
- Success Criteria + basic Failure detection & fallback
- **Source Plan** (how evidence will be obtained/verified)
**For HIGH (150250 words for analysis only):**
- Comprehensive objective analysis; Intent & Scope
- 58 step Plan with dependencies/parallelism
- **Uncertainty & Probes** (key unknowns → probe → stop condition)
- Measurable Success Criteria; Failure detectors & fallbacks
- **Source Plan** (evidence acquisition & validation)
- **Reflection Hooks** (escalation/de-escalation triggers)

View File

@@ -0,0 +1,9 @@
**Input Variables**
- **{{ task }}** — the task/request to analyze
- **{{ context }}** — background, history, situational context
- **{{ agent_prompt }}** — special instructions/role hints
- **{{ tools_desc }}** — available sub-agents and capabilities
**Final Output Rule**
Return the Task Transmission section (if needed) followed by the concrete analysis and planning steps according to LOW / MEDIUM / HIGH complexity.
Do not restate the framework, definitions, or rules. Output only the final structured result.

View File

@@ -0,0 +1,14 @@
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restriction:
- DO NOT make things up, especially for numbers.
- If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
- Answer with markdown format text.
- Answer in language of user's question.
- DO NOT make things up, especially for numbers.
### Information from knowledge bases
{{ knowledge }}
The above is information from knowledge bases.

View File

@@ -0,0 +1,53 @@
You are given a JSON array of TOC(tabel of content) items. Each item has at least {"title": string} and may include an existing title hierarchical level.
Task
- For each item, assign a depth label using Arabic numerals only: top-level = 1, second-level = 2, third-level = 3, etc.
- Multiple items may share the same depth (e.g., many 1s, many 2s).
- Do not use dotted numbering (no 1.1/1.2). Use a single digit string per item indicating its depth only.
- Preserve the original item order exactly. Do not insert, delete, or reorder.
- Decide levels yourself to keep a coherent hierarchy. Keep peers at the same depth.
Output
- Return a valid JSON array only (no extra text).
- Each element must be {"level": "1|2|3", "title": <original title string>}.
- title must be the original title string.
Examples
Example A (chapters with sections)
Input:
["Chapter 1 Methods", "Section 1 Definition", "Section 2 Process", "Chapter 2 Experiment"]
Output:
[
{"level":"1","title":"Chapter 1 Methods"},
{"level":"2","title":"Section 1 Definition"},
{"level":"2","title":"Section 2 Process"},
{"level":"1","title":"Chapter 2 Experiment"}
]
Example B (parts with chapters)
Input:
["Part I Theory", "Chapter 1 Basics", "Chapter 2 Methods", "Part II Applications", "Chapter 3 Case Studies"]
Output:
[
{"level":"1","title":"Part I Theory"},
{"level":"2","title":"Chapter 1 Basics"},
{"level":"2","title":"Chapter 2 Methods"},
{"level":"1","title":"Part II Applications"},
{"level":"2","title":"Chapter 3 Case Studies"}
]
Example C (plain headings)
Input:
["Introduction", "Background and Motivation", "Related Work", "Methodology", "Evaluation"]
Output:
[
{"level":"1","title":"Introduction"},
{"level":"2","title":"Background and Motivation"},
{"level":"2","title":"Related Work"},
{"level":"1","title":"Methodology"},
{"level":"1","title":"Evaluation"}
]

View File

@@ -0,0 +1,13 @@
You are an agent for adding correct citations to the given text by user.
You are given a piece of text within [ID:<ID>] tags, which was generated based on the provided sources.
However, the sources are not cited in the [ID:<ID>].
Your task is to enhance user trust by generating correct, appropriate citations for this report.
{{ example }}
<context>
{{ sources }}
</context>

View File

@@ -0,0 +1,109 @@
Based on the provided document or chat history, add citations to the input text using the format specified later.
# Citation Requirements:
## Technical Rules:
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
- Place citations at the end of sentences, before punctuation
- Maximum 4 citations per sentence
- DO NOT cite content not from <context></context>
- DO NOT modify whitespace or original text
- STRICTLY prohibit non-standard formatting (~~, etc.)
## What MUST Be Cited:
1. **Quantitative data**: Numbers, percentages, statistics, measurements
2. **Temporal claims**: Dates, timeframes, sequences of events
3. **Causal relationships**: Claims about cause and effect
4. **Comparative statements**: Rankings, comparisons, superlatives
5. **Technical definitions**: Specialized terms, concepts, methodologies
6. **Direct attributions**: What someone said, did, or believes
7. **Predictions/forecasts**: Future projections, trend analyses
8. **Controversial claims**: Disputed facts, minority opinions
## What Should NOT Be Cited:
- Common knowledge (e.g., "The sun rises in the east")
- Transitional phrases
- General introductions
- Your own analysis or synthesis (unless directly from source)
# Comprehensive Examples:
## Example 1: Data and Statistics
<context>
ID: 45
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
ID: 46
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
</context>
USER: How is the smartphone market performing?
ASSISTANT:
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
## Example 2: Causal Relationships and Predictions
<context>
ID: 78
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
ID: 79
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
</context>
USER: What's happening in the real estate market?
ASSISTANT:
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
## Example 3: Technical Concepts and Comparisons
<context>
ID: 101
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
ID: 102
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
</context>
USER: Explain quantum computing advantages.
ASSISTANT:
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
## Example 4: Mixed Citation Needs
<context>
ID: 201
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
ID: 202
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
ID: 203
└── Content: The company's AI division currently represents only 5% of total revenue.
</context>
USER: What's the company's new strategy?
ASSISTANT:
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
## Example 5: When NOT to Cite
<context>
ID: 301
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
</context>
USER: Tell me about Paris.
ASSISTANT:
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
(Note: Only the specific date needs citation, not common knowledge about Paris)
--- Examples END ---
REMEMBER:
- Cite FACTS, not opinions or transitions
- Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...

View File

@@ -0,0 +1,32 @@
## Role
You are a text analyzer.
## Task
Add tags (labels) to a given piece of text content based on the examples and the entire tag set.
## Steps
- Review the tag/label set.
- Review examples which all consist of both text content and assigned tags with relevance score in JSON format.
- Summarize the text content, and tag it with the top {{ topn }} most relevant tags from the set of tags/labels and the corresponding relevance score.
## Requirements
- The tags MUST be from the tag set.
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
- The relevance score must range from 1 to 10.
- Output keywords ONLY.
# TAG SET
{{ all_tags | join(', ') }}
{% for ex in examples %}
# Examples {{ loop.index0 }}
### Text Content
{{ ex.content }}
Output:
{{ ex.tags_json }}
{% endfor %}
# Real Data
### Text Content
{{ content }}

View File

@@ -0,0 +1,35 @@
## Role
A streamlined multilingual translator.
## Behavior Rules
1. Accept batch translation requests in the following format:
**Input:** `[text]`
**Target Languages:** comma-separated list
2. Maintain:
- Original formatting (tables, lists, spacing)
- Technical terminology accuracy
- Cultural context appropriateness
3. Output translations in the following format:
[Translation in language1]
###
[Translation in language2]
---
## Example
**Input:**
Hello World! Let's discuss AI safety.
===
Chinese, French, Japanese
**Output:**
你好世界!让我们讨论人工智能安全问题。
###
Bonjour le monde ! Parlons de la sécurité de l'IA.
###
こんにちは世界AIの安全性について話し合いましょう。

View File

@@ -0,0 +1,7 @@
**Input:**
{{ query }}
===
{{ languages | join(', ') }}
**Output:**

View File

@@ -0,0 +1,62 @@
## Role
A helpful assistant.
## Task & Steps
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative dates, convert them into absolute dates based on today ({{ today }}).
- "yesterday" = {{ yesterday }}, "tomorrow" = {{ tomorrow }}
## Requirements & Restrictions
- If the user's latest question is already complete, don't do anything — just return the original question.
- DON'T generate anything except a refined question.
{% if language %}
- Text generated MUST be in {{ language }}.
{% else %}
- Text generated MUST be in the same language as the original user's question.
{% endif %}
---
## Examples
### Example 1
**Conversation:**
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
**Output:** What's the name of Donald Trump's mother?
---
### Example 2
**Conversation:**
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
USER: What's her full name?
**Output:** What's the full name of Donald Trump's mother Mary Trump?
---
### Example 3
**Conversation:**
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
**Output:** What's the weather in Rochester on {{ tomorrow }}?
---
## Real Data
**Conversation:**
{{ conversation }}

View File

@@ -0,0 +1,728 @@
import datetime
import json
import logging
import re
from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
import trio
from app.core.rag.common.misc_utils import hash_str2int
from app.core.rag.nlp import rag_tokenizer
from .template import load_prompt
from app.core.rag.common.constants import TAG_FLD
from app.core.rag.common.token_utils import encoder, num_tokens_from_string
STOP_TOKEN="<|STOP|>"
COMPLETE_TASK="complete_task"
INPUT_UTILIZATION = 0.5
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
def chunks_format(reference):
return [
{
"id": get_value(chunk, "chunk_id", "id"),
"content": get_value(chunk, "content", "content_with_weight"),
"document_id": get_value(chunk, "doc_id", "document_id"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"image_id": get_value(chunk, "image_id", "img_id"),
"positions": get_value(chunk, "positions", "position_int"),
"url": chunk.get("url"),
"similarity": chunk.get("similarity"),
"vector_similarity": chunk.get("vector_similarity"),
"term_similarity": chunk.get("term_similarity"),
"doc_type": chunk.get("doc_type_kwd"),
}
for chunk in reference.get("chunks", [])
]
def message_fit_in(msg, max_length=4000):
def count():
nonlocal msg
tks_cnts = []
for m in msg:
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
total = 0
for m in tks_cnts:
total += m["count"]
return total
c = count()
if c < max_length:
return c, msg
msg_ = [m for m in msg if m["role"] == "system"]
if len(msg) > 1:
msg_.append(msg[-1])
msg = msg_
c = count()
if c < max_length:
return c, msg
ll = num_tokens_from_string(msg_[0]["content"])
ll2 = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[0]["content"] = m
return max_length, msg
m = msg_[-1]["content"]
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[-1]["content"] = m
return max_length, msg
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
STRUCTURED_OUTPUT_PROMPT = load_prompt("structured_output_prompt")
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")
META_FILTER = load_prompt("meta_filter")
ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt(user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
def citation_plus(sources: str) -> str:
template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
return template.render(example=citation_prompt(), sources=sources)
def keyword_extraction(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
def question_proposal(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
def full_question(messages=[], language=None, chat_mdl=None):
conv = []
for m in messages:
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conversation = "\n".join(conv)
today = datetime.date.today().isoformat()
yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()
template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(
today=today,
yesterday=yesterday,
tomorrow=tomorrow,
conversation=conversation,
language=language,
)
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def cross_languages(query, languages=[], chat_mdl=None):
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
for ex in examples:
ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)
rendered_prompt = template.render(
topn=topn,
all_tags=all_tags,
examples=examples,
content=content,
)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
raise Exception(kwd)
try:
obj = json_repair.loads(kwd)
except json_repair.JSONDecodeError:
try:
result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
result = "{" + result.split("{")[1].split("}")[0] + "}"
obj = json_repair.loads(result)
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
res = {}
for k, v in obj.items():
try:
if int(v) > 0:
res[str(k)] = int(v)
except Exception:
pass
return res
def vision_llm_describe_prompt(page=None) -> str:
template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
return template.render(page=page)
def vision_llm_figure_describe_prompt() -> str:
template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
return template.render()
def tool_schema(tools_description: list[dict], complete_task=False):
if not tools_description:
return ""
desc = {}
if complete_task:
desc[COMPLETE_TASK] = {
"type": "function",
"function": {
"name": COMPLETE_TASK,
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
"properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
}
for tool in tools_description:
desc[tool["function"]["name"]] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
def form_history(history, limit=-6):
context = ""
for h in history[limit:]:
if h["role"] == "system":
continue
role = "USER"
if h["role"].upper()!= role:
role = "AGENT"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
return context
def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
tools_desc = tool_schema(tools_description)
context = ""
if user_defined_prompts.get("task_analysis"):
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts["task_analysis"])
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return ""
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
hist = deepcopy(history)
if hist[-1]["role"] == "user":
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:], stop=["<|stop|>"])
tk_cnt = num_tokens_from_string(json_str)
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
return json_str, tk_cnt
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
user_prompt = template.render(goal=goal, tool_calls=tool_calls)
hist = deepcopy(history)
if hist[-1]["role"] == "user":
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
{}
**Reflection**
{}
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
def structured_output_prompt(schema=None) -> str:
template = PROMPT_JINJA_ENV.from_string(STRUCTURED_OUTPUT_PROMPT)
return template.render(schema=schema)
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
metadata_keys=json.dumps(meta_data),
user_question=query
)
user_prompt = "Generate filters:"
ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
ans = json_repair.loads(ans)
assert isinstance(ans, list), ans
return ans
except Exception:
logging.exception(f"Loading json failure: {ans}")
return []
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
from app.core.rag.graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
set_llm_cache(chat_mdl.llm_name, system_prompt, ans, user_prompt, gen_conf)
return res
except Exception:
logging.exception(f"Loading json failure: {ans}")
TOC_DETECTION = load_prompt("toc_detection")
def detect_table_of_contents(page_1024:list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
return toc_secs
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>,
"physical_index": "<physical_index_X>" (keep the format)
},
...
]
Only add the physical_index to the sections that are in the provided pages.
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
return gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
k1 = (it["structure"]+it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
if k2 not in toc_map:
toc_map[k2] = []
toc_map[k1].append(i)
toc_map[k2].append(i)
for it in toc_arr:
it["indices"] = []
for i, sec in enumerate(sections):
sec = sec.strip()
if sec.replace(" ", "") in toc_map:
for j in toc_map[sec.replace(" ", "")]:
toc_arr[j]["indices"].append(i)
all_pathes = []
def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
if path:
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
dfs(start+1, path)
return
added = False
for j in toc_arr[start]["indices"]:
if path and j < path[-1][0]:
continue
_path = deepcopy(path)
_path.append((j, start))
added = True
dfs(start+1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
path = max(all_pathes, key=lambda x:len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
toc_arr[i]["indices"] = [j]
print(json.dumps(toc_arr, ensure_ascii=False, indent=2))
i = 0
while i < len(toc_arr):
it = toc_arr[i]
if it["indices"]:
i += 1
continue
if i>0 and toc_arr[i-1]["indices"]:
st_i = toc_arr[i-1]["indices"][-1]
else:
st_i = 0
e = i + 1
while e <len(toc_arr) and not toc_arr[e]["indices"]:
e += 1
if e >= len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
text=sections[j]), "Only JSON please.", chat_mdl)
if ans["exist"] == "yes":
it["indices"].append(j)
break
i += 1
return toc_arr
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
prompt = """
You are given a raw table of contents and a table of contents.
Your job is to check if the table of contents is complete.
Reply format:
{{
"thinking": <why do you think the cleaned table of contents is complete or not>
"completed": "yes" or "no"
}}
Directly return the final JSON structure. Do not output anything else."""
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
response = gen_json(prompt, "Only JSON please.", chat_mdl)
return response['completed']
def toc_transformer(toc_pages, chat_mdl):
init_prompt = """
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
The `title` is a short phrase or a several-words term.
The response should be in the following JSON format:
[
{
"structure": <structure index, "x.x.x" or None> (string),
"title": <title of the section>
},
...
],
You should transform the full table of contents in one go.
Directly return the final JSON structure, do not output anything else. """
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
while not (if_complete == "yes"):
prompt = f"""
Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
The response should be in the following JSON format:
The raw table of contents json structure is:
{toc_content}
The incomplete transformed table of contents json structure is:
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
Please continue the json structure, directly output the remaining part of the json structure."""
new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
return last_complete
TOC_LEVELS = load_prompt("assign_toc_levels")
def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
if not toc_secs:
return []
return gen_json(
PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
str(toc_secs),
chat_mdl,
gen_conf
)
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
txt_info["toc"] = ans if ans and not isinstance(ans, str) else []
if callback:
callback(msg="")
except Exception as e:
logging.exception(e)
def split_chunks(chunks, max_length: int):
"""
Pack chunks into batches according to max_length, returning [{"id": idx, "text": chunk_text}, ...].
Do not split a single chunk, even if it exceeds max_length.
"""
result = []
batch, batch_tokens = [], 0
for idx, chunk in enumerate(chunks):
t = num_tokens_from_string(chunk)
if batch_tokens + t > max_length:
result.append(batch)
batch, batch_tokens = [], 0
batch.append({idx: chunk})
batch_tokens += t
if batch:
result.append(batch)
return result
async def run_toc_from_text(chunks, chat_mdl, callback=None):
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
titles = []
chunks_res = []
async with trio.open_nursery() as nursery:
for i, chunk in enumerate(chunk_sections):
if not chunk:
continue
chunks_res.append({"chunks": chunk})
nursery.start_soon(gen_toc_from_text, chunks_res[-1], chat_mdl, callback)
for chunk in chunks_res:
titles.extend(chunk.get("toc", []))
# Filter out entries with title == -1
prune = len(titles) > 512
max_len = 12 if prune else 22
filtered = []
for x in titles:
if not isinstance(x, dict) or not x.get("title") or x["title"] == "-1":
continue
if len(rag_tokenizer.tokenize(x["title"]).split(" ")) > max_len:
continue
if re.match(r"[0-9,.()/ -]+$", x["title"]):
continue
filtered.append(x)
logging.info(f"\n\nFiltered TOC sections:\n{filtered}")
if not filtered:
return []
# Generate initial level (level/title)
raw_structure = [x.get("title", "") for x in filtered]
# Assign hierarchy levels using LLM
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
if not toc_with_levels:
return []
# Merge structure and content (by index)
prune = len(toc_with_levels) > 512
max_lvl = sorted([t.get("level", "0") for t in toc_with_levels if isinstance(t, dict)])[-1]
merged = []
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
"level": toc_item.get("level", "0"),
"title": toc_item.get("title", ""),
"chunk_id": src_item.get("chunk_id", ""),
})
return merged
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
import numpy as np
try:
ans = gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
id2score = {}
for ti, sc in zip(toc, ans):
if not isinstance(sc, dict) or sc.get("score", -1) < 1:
continue
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
id2score[id].append(sc["score"]/5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
except Exception as e:
logging.exception(e)
return []

View File

@@ -0,0 +1,16 @@
## Role
You are a text analyzer.
## Task
Extract the most important keywords/phrases of a given piece of text content.
## Requirements
- Summarize the text content, and give the top {{ topn }} important keywords/phrases.
- The keywords MUST be in the same language as the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Output keywords ONLY.
---
## Text Content
{{ content }}

View File

@@ -0,0 +1,53 @@
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:
1. **Metadata Structure**:
- Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
- Example:
{
"color": {"red": ["doc1"], "blue": ["doc2"]},
"listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
}
2. **Output Requirements**:
- Always output a JSON array of filter objects
- Each object must have:
"key": (metadata attribute name),
"value": (string value to compare),
"op": (operator from allowed list)
3. **Operator Guide**:
- Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
- Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
- Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
- Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01])
4. **Processing Steps**:
a) Identify ALL filterable attributes in the query (both explicit and implicit)
b) For dates:
- Infer missing year from current date if needed
- Always format dates as "YYYY-MM-DD"
- Convert ranges: [≥ start, < end]
c) For values: Match EXACTLY to metadata's value keys
d) Skip conditions if:
- Attribute doesn't exist in metadata
- Value has no match in metadata
5. **Example**:
- User query: "上市日期七月份的有哪些商品,不要蓝色的"
- Metadata: { "color": {...}, "listing_date": {...} }
- Output:
[
{"key": "listing_date", "value": "2025-07-01", "op": "≥"},
{"key": "listing_date", "value": "2025-08-01", "op": "<"},
{"key": "color", "value": "blue", "op": "≠"}
]
6. **Final Output**:
- ONLY output valid JSON array
- NO additional text/explanations
**Current Task**:
- Today's date: {{current_date}}
- Available metadata keys: {{metadata_keys}}
- User query: "{{user_question}}"

View File

@@ -0,0 +1,92 @@
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
Your job is:
1. Based on the task analysis, chose some right tools to execute.
2. Track progress and adapt plans(tool calls) when necessary.
3. Use `complete_task` if no further step you need to take from tools. (All necessary steps done or little hope to be done)
# ========== TASK ANALYSIS =============
{{ task_analysis }}
# ========== TOOLS (JSON-Schema) ==========
You may invoke only the tools listed below.
Return a JSON array of objects in which item is with exactly two top-level keys:
• "name": the tool to call
• "arguments": an object whose keys/values satisfy the schema
{{ desc }}
# ========== MULTI-STEP EXECUTION ==========
When tasks require multiple independent steps, you can execute them in parallel by returning multiple tool calls in a single JSON array.
**Data Collection**: Gathering information from multiple sources simultaneously
**Validation**: Cross-checking facts using different tools
**Comprehensive Analysis**: Analyzing different aspects of the same problem
**Efficiency**: Reducing total execution time when steps don't depend on each other
**Example Scenarios:**
- Searching multiple databases for the same query
- Checking weather in multiple cities
- Validating information through different APIs
- Performing calculations on different datasets
- Gathering user preferences from multiple sources
# ========== RESPONSE FORMAT ==========
**When you need a tool**
Return ONLY the Json (no additional keys, no commentary, end with `<|stop|>`), such as following:
[{
"name": "<tool_name1>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name2>",
"arguments": { /* tool arguments matching its schema */ }
}...]<|stop|>
**When you need multiple tools:**
Return ONLY:
[{
"name": "<tool_name1>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name2>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name3>",
"arguments": { /* tool arguments matching its schema */ }
}...]<|stop|>
**When you are certain the task is solved OR no further information can be obtained**
Return ONLY:
[{
"name": "complete_task",
"arguments": { "answer": "<final answer text>" }
}]<|stop|>
<verification_steps>
Before providing a final answer:
1. Double-check all gathered information
2. Verify calculations and logic
3. Ensure answer matches exactly what was asked
4. Confirm answer format meets requirements
5. Run additional verification if confidence is not 100%
</verification_steps>
<error_handling>
If you encounter issues:
1. Try alternative approaches before giving up
2. Use different tools or combinations of tools
3. Break complex problems into simpler sub-tasks
4. Verify intermediate results frequently
5. Never return "I cannot answer" without exhausting all options
</error_handling>
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
# ========== REASONING & REFLECTION ==========
You may think privately (not shown to the user) before producing each JSON object.
Internal guideline:
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
2. **Act**: Emit the JSON object to call the tool.
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.

View File

@@ -0,0 +1,19 @@
## Role
You are a text analyzer.
## Task
Propose {{ topn }} questions about a given piece of text content.
## Requirements
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
- The questions MUST be in the same language as the given piece of text content.
- One question per line.
- Output questions ONLY.
---
## Text Content
{{ content }}

View File

@@ -0,0 +1,30 @@
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).
**Rules**:
1. Analyze each result's contribution to both:
- The overall goal (primary priority)
- The current sub-goal (secondary priority)
2. Sort from MOST relevant (highest impact) to LEAST relevant
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]
🔹 Overall Goal: {{ goal }}
🔹 Sub-goal: {{ sub_goal }}
**Examples**:
🔹 Tool Response:
- index: 0
> Tokyo temperature is 78°F.
- index: 1
> Error: Authentication failed (expired API key).
- index: 2
> Available: 12 widgets in stock (max 5 per customer).
→ rank: [1,2,0]<|stop|>
**Your Turn**:
🔹 Tool Response:
{% for f in results %}
- index: f.i
> f.content
{% endfor %}

View File

@@ -0,0 +1,75 @@
**Context**:
- To achieve the goal: {{ goal }}.
- You have executed following tool calls:
{% for call in tool_calls %}
Tool call: `{{ call.name }}`
Results: {{ call.result }}
{% endfor %}
## Task Complexity Analysis & Reflection Scope
**First, analyze the task complexity using these dimensions:**
### Complexity Assessment Matrix
- **Scope Breadth**: Single-step (1) | Multi-step (2) | Multi-domain (3)
- **Data Dependency**: Self-contained (1) | External inputs (2) | Multiple sources (3)
- **Decision Points**: Linear (1) | Few branches (2) | Complex logic (3)
- **Risk Level**: Low (1) | Medium (2) | High (3)
**Complexity Score**: Sum all dimensions (4-12 points)
---
## Task Transmission Assessment
**Note**: This section is not subject to word count limitations when transmission is needed, as it serves critical handoff functions.
**Evaluate if task transmission information is needed:**
- **Is this an initial step?** If yes, skip this section
- **Are there downstream agents/steps?** If no, provide minimal transmission
- **Is there critical state/context to preserve?** If yes, include full transmission
### If Task Transmission is Needed:
- **Current State Summary**: [1-2 sentences on where we are]
- **Key Data/Results**: [Critical findings that must carry forward]
- **Context Dependencies**: [Essential context for next agent/step]
- **Unresolved Items**: [Issues requiring continuation]
- **Status for User**: [Clear status update in user terms]
- **Technical State**: [System state for technical handoffs]
---
## Situational Reflection (Adjust Length Based on Complexity Score)
### Reflection Guidelines:
- **Simple Tasks (4-5 points)**: ~50-100 words, focus on completion status and immediate next step
- **Moderate Tasks (6-8 points)**: ~100-200 words, include core details and main risks
- **Complex Tasks (9-12 points)**: ~200-300 words, provide full analysis and alternatives
### 1. Goal Achievement Status
- Does the current outcome align with the original purpose of this task phase?
- If not, what critical gaps exist?
### 2. Step Completion Check
- Which planned steps were completed? (List verified items)
- Which steps are pending/incomplete? (Specify exactly what's missing)
### 3. Information Adequacy
- Is the collected data sufficient to proceed?
- What key information is still needed? (e.g., metrics, user input, external data)
### 4. Critical Observations
- Unexpected outcomes: [Flag anomalies/errors]
- Risks/blockers: [Identify immediate obstacles]
- Accuracy concerns: [Highlight unreliable results]
### 5. Next-Step Recommendations
- Proposed immediate action: [Concrete next step]
- Alternative strategies if blocked: [Workaround solution]
- Tools/inputs required for next phase: [Specify resources]
---
**Output Instructions:**
1. First determine your complexity score
2. Assess if task transmission section is needed using the evaluation questions
3. Provide situational reflection with length appropriate to complexity
4. Use clear headers for easy parsing by downstream systems

View File

@@ -0,0 +1,55 @@
# Role
You are an AI language model assistant tasked with generating **5-10 related questions** based on a users original query.
These questions should help **expand the search query scope** and **improve search relevance**.
---
## Instructions
**Input:**
You are provided with a **users question**.
**Output:**
Generate **5-10 alternative questions** that are **related** to the original user question.
These alternatives should help retrieve a **broader range of relevant documents** from a vector database.
**Context:**
Focus on **rephrasing** the original question in different ways, ensuring the alternative questions are **diverse but still connected** to the topic of the original query.
Do **not** create overly obscure, irrelevant, or unrelated questions.
**Fallback:**
If you cannot generate any relevant alternatives, do **not** return any questions.
---
## Guidance
1. Each alternative should be **unique** but still **relevant** to the original query.
2. Keep the phrasing **clear, concise, and easy to understand**.
3. Avoid overly technical jargon or specialized terms **unless directly relevant**.
4. Ensure that each question **broadens** the search angle, **not narrows** it.
---
## Example
**Original Question:**
> What are the benefits of electric vehicles?
**Alternative Questions:**
1. How do electric vehicles impact the environment?
2. What are the advantages of owning an electric car?
3. What is the cost-effectiveness of electric vehicles?
4. How do electric vehicles compare to traditional cars in terms of fuel efficiency?
5. What are the environmental benefits of switching to electric cars?
6. How do electric vehicles help reduce carbon emissions?
7. Why are electric vehicles becoming more popular?
8. What are the long-term savings of using electric vehicles?
9. How do electric vehicles contribute to sustainability?
10. What are the key benefits of electric vehicles for consumers?
---
## Reason
Rephrasing the original query into multiple alternative questions helps the user explore **different aspects** of their search topic, improving the **quality of search results**.
These questions guide the search engine to provide a **more comprehensive set** of relevant documents.

View File

@@ -0,0 +1,16 @@
Youre a helpful AI assistant. You could answer questions and output in JSON format.
constraints:
- You must output in JSON format.
- Do not output boolean value, use string type instead.
- Do not output integer or float value, use number type instead.
eg:
Here is the JSON schema:
{"properties": {"age": {"type": "number","description": ""},"name": {"type": "string","description": ""}},"required": ["age","name"],"type": "Object Array String Number Boolean","value": ""}
Here is the user's question:
My name is John Doe and I am 30 years old.
output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{ schema }}

View File

@@ -0,0 +1,35 @@
**Role**: AI Assistant
**Task**: Summarize tool call responses
**Rules**:
1. Context: You've executed a tool (API/function) and received a response.
2. Condense the response into 1-2 short sentences.
3. Never omit:
- Success/error status
- Core results (e.g., data points, decisions)
- Critical constraints (e.g., limits, conditions)
4. Exclude technical details like timestamps/request IDs unless crucial.
5. Use language as the same as main content of the tool response.
**Response Template**:
"[Status] + [Key Outcome] + [Critical Constraints]"
**Examples**:
🔹 Tool Response:
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
→ Summary: "Success: Tokyo temperature is 78°F."
🔹 Tool Response:
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
→ Summary: "Error: Authentication failed (expired API key)."
🔹 Tool Response:
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."
**Your Turn**:
- Tool call: {{ name }}
- Tool inputs as following:
{{ params }}
- Tool Response:
{{ result }}

View File

@@ -0,0 +1,20 @@
import os
PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}
def load_prompt(name: str) -> str:
if name in _loaded_prompts:
return _loaded_prompts[name]
path = os.path.join(PROMPT_DIR, f"{name}.md")
if not os.path.isfile(path):
raise FileNotFoundError(f"Prompt file '{name}.md' not found in prompts/ directory.")
with open(path, "r", encoding="utf-8") as f:
content = f.read().strip()
_loaded_prompts[name] = content
return content

View File

@@ -0,0 +1,29 @@
You are an AI assistant designed to analyze text content and detect whether a table of contents (TOC) list exists on the given page. Follow these steps:
1. **Analyze the Input**: Carefully review the provided text content.
2. **Identify Key Features**: Look for common indicators of a TOC, such as:
- Section titles or headings paired with page numbers.
- Patterns like repeated formatting (e.g., bold/italicized text, dots/dashes between titles and numbers).
- Phrases like "Table of Contents," "Contents," or similar headings.
- Logical grouping of topics/subtopics with sequential page references.
3. **Discern Negative Features**:
- The text contains no numbers, or the numbers present are clearly not page references (e.g., dates, statistical figures, phone numbers, version numbers).
- The text consists of full, descriptive sentences and paragraphs that form a narrative, present arguments, or explain concepts, rather than succinctly listing topics.
- Contains citations with authors, publication years, journal titles, and page ranges (e.g., "Smith, J. (2020). Journal Title, 10(2), 45-67.").
- Lists keywords or terms followed by multiple page numbers, often in alphabetical order.
- Comprises terms followed by their definitions or explanations.
- Labeled with headers like "Appendix A," "Appendix B," etc.
- Contains expressive language thanking individuals or organizations for their support or contributions.
4. **Evaluate Evidence**: Weigh the presence/absence of these features to determine if the content resembles a TOC.
5. **Output Format**: Provide your response in the following JSON structure:
```json
{
"reasoning": "Step-by-step explanation of your analysis based on the features identified." ,
"exists": true/false
}
```
6. **DO NOT** output anything else except JSON structure.
**Input text Content ( Text-Only Extraction ):**
{{ page_txt }}

View File

@@ -0,0 +1,53 @@
You are an expert parser and data formatter. Your task is to analyze the provided table of contents (TOC) text and convert it into a valid JSON array of objects.
**Instructions:**
1. Analyze each line of the input TOC.
2. For each line, extract the following three pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5", "A.1"). If a line has no visible numbering or structure indicator (like a main "Chapter" title), use `null`.
* `title`: The textual title of the section or chapter. This should be the main descriptive text, clean and without the page number.
3. Output **only** a valid JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json) in your response.
**JSON Format:**
The output must be a list of objects following this exact schema:
```json
[
{
"structure": <structure index, "x.x.x" or None> (string,
"title": <title of the section>
},
...
]
```
**Input Example:**
```
Contents
1 Introduction to the System ... 1
1.1 Overview .... 2
1.2 Key Features .... 5
2 Installation Guide ....8
2.1 Prerequisites ........ 9
2.2 Step-by-Step Process ........ 12
Appendix A: Specifications ..... 45
References ... 47
```
**Expected Output For The Example:**
```json
[
{"structure": null, "title": "Contents"},
{"structure": "1", "title": "Introduction to the System"},
{"structure": "1.1", "title": "Overview"},
{"structure": "1.2", "title": "Key Features"},
{"structure": "2", "title": "Installation Guide"},
{"structure": "2.1", "title": "Prerequisites"},
{"structure": "2.2", "title": "Step-by-Step Process"},
{"structure": "A", "title": "Specifications"},
{"structure": null, "title": "References"}
]
```
**Now, process the following TOC input:**
```
{{ toc_page }}
```

View File

@@ -0,0 +1,60 @@
You are an expert parser and data formatter, currently in the process of building a JSON array from a multi-page table of contents (TOC). Your task is to analyze the new page of content and **append** the new entries to the existing JSON array.
**Instructions:**
1. You will be given two inputs:
* `current_page_text`: The text content from the new page of the TOC.
* `existing_json`: The valid JSON array you have generated from the previous pages.
2. Analyze each line of the `current_page_text` input.
3. For each new line, extract the following three pieces of information:
* `structure`: The hierarchical index/numbering (e.g., "1", "2.1", "3.2.5"). Use `null` if none exists.
* `title`: The clean textual title of the section or chapter.
* `page`: The page number on which the section starts. Extract only the number. Use `null` if not present.
4. **Append these new entries** to the `existing_json` array. Do not modify, reorder, or delete any of the existing entries.
5. Output **only** the complete, updated JSON array. Do not include any other text, explanations, or markdown code block fences (like ```json).
**JSON Format:**
The output must be a valid JSON array following this schema:
```json
[
{
"structure": <string or null>,
"title": <string>,
"page": <number or null>
},
...
]
```
**Input Example:**
`current_page_text`:
```
3.2 Advanced Configuration ........... 25
3.3 Troubleshooting .................. 28
4 User Management .................... 30
```
`existing_json`:
```json
[
{"structure": "1", "title": "Introduction", "page": 1},
{"structure": "2", "title": "Installation", "page": 5},
{"structure": "3", "title": "Configuration", "page": 12},
{"structure": "3.1", "title": "Basic Setup", "page": 15}
]
```
**Expected Output For The Example:**
```json
[
{"structure": "3.2", "title": "Advanced Configuration", "page": 25},
{"structure": "3.3", "title": "Troubleshooting", "page": 28},
{"structure": "4", "title": "User Management", "page": 30}
]
```
**Now, process the following inputs:**
`current_page_text`:
{{ toc_page }}
`existing_json`:
{{ toc_json }}

View File

@@ -0,0 +1,119 @@
You are a robust Table-of-Contents (TOC) extractor.
GOAL
Given a dictionary of chunks {"<chunk_ID>": chunk_text}, extract TOC-like headings and return a strict JSON array of objects:
[
{"title": "", "chunk_id": ""},
...
]
FIELDS
- "title": the heading text (clean, no page numbers or leader dots).
- If any part of a chunk has no valid heading, output that part as {"title":"-1", ...}.
- "chunk_id": the chunk ID (string).
- One chunk can yield multiple JSON objects in order (unmatched text + one or more headings).
RULES
1) Preserve input chunk order strictly.
2) If a chunk contains multiple headings, expand them in order:
- Pre-heading narrative → {"title":"-1","chunk_id":"<chunk_ID>"}
- Then each heading → {"title":"...","chunk_id":"<chunk_ID>"}
3) Do not merge outputs across chunks; each object refers to exactly one chunk ID.
4) "title" must be non-empty (or exactly "-1"). "chunk_id" must be a string (chunk ID).
5) When ambiguous, prefer "-1" unless the text strongly looks like a heading.
HEADING DETECTION (cues, not hard rules)
- Appears near line start, short isolated phrase, often followed by content.
- May contain separators: — —— - : · •
- Numbering styles:
• 第[一二三四五六七八九十百]+(篇|章|节|条)
• [(]?[一二三四五六七八九十]+[)]?
• [(]?[①②③④⑤⑥⑦⑧⑨⑩][)]?
• ^\d+(\.\d+)*[).]?\s*
• ^[IVXLCDM]+[).]
• ^[A-Z][).]
- Canonical section cues (general only):
Common heading indicators include words such as:
"Overview", "Introduction", "Background", "Purpose", "Scope", "Definition",
"Method", "Procedure", "Result", "Discussion", "Summary", "Conclusion",
"Appendix", "Reference", "Annex", "Acknowledgment", "Disclaimer".
These are soft cues, not strict requirements.
- Length restriction:
• Chinese heading: ≤25 characters
• English heading: ≤80 characters
- Exclude long narrative sentences, continuous prose, or bullet-style lists → output as "-1".
OUTPUT FORMAT
- Return ONLY a valid JSON array of {"title","content"} objects.
- No reasoning or commentary.
EXAMPLES
Example 1 — No heading
Input:
[{"0": "Copyright page · Publication info (ISBN 123-456). All rights reserved."}, ...]
Output:
[
{"title":"-1","chunk_id":"0"},
...
]
Example 2 — One heading
Input:
[{"1": "Chapter 1: General Provisions This chapter defines the overall rules…"}, ...]
Output:
[
{"title":"Chapter 1: General Provisions","chunk_id":"1"},
...
]
Example 3 — Narrative + heading
Input:
[{"2": "This paragraph introduces the background and goals. Section 2: Definitions Key terms are explained…"}, ...]
Output:
[
{"title":"Section 2: Definitions","chunk_id":"2"},
...
]
Example 4 — Multiple headings in one chunk
Input:
[{"3": "Declarations and Commitments (I) Party B commits… (II) Party C commits… Appendix A Data Specification"}, ...]
Output:
[
{"title":"Declarations and Commitments","chunk_id":"3"},
{"title":"(I) Party B commits","chunk_id":"3"},
{"title":"(II) Party C commits","chunk_id":"3"},
{"title":"Appendix A Data Specification","chunk_id":"3"},
...
]
Example 5 — Numbering styles
Input:
[{"4": "1. Scope: Defines boundaries. 2) Definitions: Terms used. III) Methods Overview."}, ...]
Output:
[
{"title":"1. Scope","chunk_id":"4"},
{"title":"2) Definitions","chunk_id":"4"},
{"title":"III) Methods Overview","chunk_id":"4"},
...
]
Example 6 — Long list (NOT headings)
Input:
{"5": "Item list: apples, bananas, strawberries, blueberries, mangos, peaches"}, ...]
Output:
[
{"title":"-1","chunk_id":"5"},
...
]
Example 7 — Mixed Chinese/English
Input:
{"6": "出版信息略This standard follows industry practices. Chapter 1: Overview 摘要… 第2节术语与缩略语"}, ...]
Output:
[
{"title":"Chapter 1: Overview","chunk_id":"6"},
{"title":"第2节术语与缩略语","chunk_id":"6"},
...
]

View File

@@ -0,0 +1,8 @@
OUTPUT FORMAT
- Return ONLY the JSON array.
- Use double quotes.
- No extra commentary.
- Keep language of "title" the same as the input.
INPUT
{{text}}

View File

@@ -0,0 +1,20 @@
You are an expert analyst tasked with matching text content to the title.
**Instructions:**
1. Analyze the given title with its numeric structure index and the provided text.
2. Determine whether the title is mentioned as a section tile in the given text.
3. Provide a concise, step-by-step reasoning for your decision.
4. Output **only** the complete JSON object. Do not include any other text, explanations, or markdown code block fences (like ```json).
**Output Format:**
Your output must be a valid JSON object with the following keys:
{
"reasoning": "Step-by-step explanation of your analysis.",
"exist": "<yes or no>",
}
** The title: **
{{ structure }} {{ title }}
** Given text: **
{{ text }}

View File

@@ -0,0 +1,118 @@
# System Prompt: TOC Relevance Evaluation
You are an expert logical reasoning assistant specializing in hierarchical Table of Contents (TOC) relevance evaluation.
## GOAL
You will receive:
1. A JSON list of TOC items, each with fields:
```json
{
"level": <integer>, // e.g., 1, 2, 3
"title": <string> // section title
}
```
2. A user query (natural language question).
You must assign a **relevance score** (integer) to every TOC entry, based on how related its `title` is to the `query`.
---
## RULES
### Scoring System
- 5 → highly relevant (directly answers or matches the query intent)
- 3 → somewhat related (same topic or partially overlaps)
- 1 → weakly related (vague or tangential)
- 0 → no clear relation
- -1 → explicitly irrelevant or contradictory
### Hierarchy Traversal
- The TOC is hierarchical: smaller `level` = higher layer (e.g., level 1 is top-level, level 2 is a subsection).
- You must traverse in **hierarchical order** — interpret the structure based on levels (1 > 2 > 3).
- If a high-level item (level 1) is strongly related (score 5), its child items (level 2, 3) are likely relevant too.
- If a high-level item is unrelated (-1 or 0), its deeper children are usually less relevant unless the titles clearly match the query.
- Lower (deeper) levels provide more specific content; prefer assigning higher scores if they directly match the query.
### Output Format
Return a **JSON array**, preserving the input order but adding a new key `"score"`:
```json
[
{"level": 1, "title": "Introduction", "score": 0},
{"level": 2, "title": "Definition of Sustainability", "score": 5}
]
```
### Constraints
- Output **only the JSON array** — no explanations or reasoning text.
### EXAMPLES
#### Example 1
Input TOC:
[
{"level": 1, "title": "Machine Learning Overview"},
{"level": 2, "title": "Supervised Learning"},
{"level": 2, "title": "Unsupervised Learning"},
{"level": 3, "title": "Applications of Deep Learning"}
]
Query:
"How is deep learning used in image classification?"
Output:
[
{"level": 1, "title": "Machine Learning Overview", "score": 3},
{"level": 2, "title": "Supervised Learning", "score": 3},
{"level": 2, "title": "Unsupervised Learning", "score": 0},
{"level": 3, "title": "Applications of Deep Learning", "score": 5}
]
---
#### Example 2
Input TOC:
[
{"level": 1, "title": "Marketing Basics"},
{"level": 2, "title": "Consumer Behavior"},
{"level": 2, "title": "Digital Marketing"},
{"level": 3, "title": "Social Media Campaigns"},
{"level": 3, "title": "SEO Optimization"}
]
Query:
"What are the best online marketing methods?"
Output:
[
{"level": 1, "title": "Marketing Basics", "score": 3},
{"level": 2, "title": "Consumer Behavior", "score": 1},
{"level": 2, "title": "Digital Marketing", "score": 5},
{"level": 3, "title": "Social Media Campaigns", "score": 5},
{"level": 3, "title": "SEO Optimization", "score": 5}
]
---
#### Example 3
Input TOC:
[
{"level": 1, "title": "Physics Overview"},
{"level": 2, "title": "Classical Mechanics"},
{"level": 3, "title": "Newtons Laws"},
{"level": 2, "title": "Thermodynamics"},
{"level": 3, "title": "Entropy and Heat Transfer"}
]
Query:
"What is entropy?"
Output:
[
{"level": 1, "title": "Physics Overview", "score": 3},
{"level": 2, "title": "Classical Mechanics", "score": 0},
{"level": 3, "title": "Newtons Laws", "score": -1},
{"level": 2, "title": "Thermodynamics", "score": 5},
{"level": 3, "title": "Entropy and Heat Transfer", "score": 5}
]

View File

@@ -0,0 +1,17 @@
# User Prompt: TOC Relevance Evaluation
You will now receive:
1. A JSON list of TOC items (each with `level` and `title`)
2. A user query string.
Traverse the TOC hierarchically based on level numbers and assign scores (5,3,1,0,-1) according to the rules in the system prompt.
Output **only** the JSON array with the added `"score"` field.
---
**Input TOC:**
{{ toc_json }}
**Query:**
{{ query }}

View File

@@ -0,0 +1,19 @@
**Task Instruction:**
You are tasked with reading and analyzing tool call result based on the following inputs: **Inputs for current call**, and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.
**Guidelines:**
1. **Analyze the Results:**
- Carefully review the content of each results of tool call.
- Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.
2. **Extract Relevant Information:**
- Select the information from the Searched Web Pages that directly contributes to advancing the previous reasoning steps.
- Ensure that the extracted information is accurate and relevant.
- **Inputs for current call:**
{{ inputs }}
- **Results:**
{{ results }}

View File

@@ -0,0 +1,23 @@
## INSTRUCTION
Transcribe the content from the provided PDF page image into clean Markdown format.
- Only output the content transcribed from the image.
- Do NOT output this instruction or any other explanation.
- If the content is missing or you do not understand the input, return an empty string.
## RULES
1. Do NOT generate examples, demonstrations, or templates.
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
5. Do NOT explain Markdown or mention that you are using Markdown.
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
{% if page %}
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
{% endif %}
> If you do not detect valid content in the image, return an empty string.

View File

@@ -0,0 +1,24 @@
## ROLE
You are an expert visual data analyst.
## GOAL
Analyze the image and provide a comprehensive description of its content. Focus on identifying the type of visual data representation (e.g., bar chart, pie chart, line graph, table, flowchart), its structure, and any text captions or labels included in the image.
## TASKS
1. Describe the overall structure of the visual representation. Specify if it is a chart, graph, table, or diagram.
2. Identify and extract any axes, legends, titles, or labels present in the image. Provide the exact text where available.
3. Extract the data points from the visual elements (e.g., bar heights, line graph coordinates, pie chart segments, table rows and columns).
4. Analyze and explain any trends, comparisons, or patterns shown in the data.
5. Capture any annotations, captions, or footnotes, and explain their relevance to the image.
6. Only include details that are explicitly present in the image. If an element (e.g., axis, legend, or caption) does not exist or is not visible, do not mention it.
## OUTPUT FORMAT (Include only sections relevant to the image content)
- Visual Type: [Type]
- Title: [Title text, if available]
- Axes / Legends / Labels: [Details, if available]
- Data Points: [Extracted data]
- Trends / Insights: [Analysis and interpretation]
- Captions / Annotations: [Text and relevance, if available]
> Ensure high accuracy, clarity, and completeness in your analysis, and include only the information present in the image. Avoid unnecessary statements about missing elements.

View File

View File

@@ -0,0 +1,255 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
import numpy as np
DEFAULT_MATCH_VECTOR_TOPN = 10
DEFAULT_MATCH_SPARSE_TOPN = 10
VEC = list | np.ndarray
@dataclass
class SparseVector:
indices: list[int]
values: list[float] | list[int] | None = None
def __post_init__(self):
assert (self.values is None) or (len(self.indices) == len(self.values))
def to_dict_old(self):
d = {"indices": self.indices}
if self.values is not None:
d["values"] = self.values
return d
def to_dict(self):
if self.values is None:
raise ValueError("SparseVector.values is None")
result = {}
for i, v in zip(self.indices, self.values):
result[str(i)] = v
return result
@staticmethod
def from_dict(d):
return SparseVector(d["indices"], d.get("values"))
def __str__(self):
return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})"
def __repr__(self):
return str(self)
class MatchTextExpr(ABC):
def __init__(
self,
fields: list[str],
matching_text: str,
topn: int,
extra_options: dict = dict(),
):
self.fields = fields
self.matching_text = matching_text
self.topn = topn
self.extra_options = extra_options
class MatchDenseExpr(ABC):
def __init__(
self,
vector_column_name: str,
embedding_data: VEC,
embedding_data_type: str,
distance_type: str,
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
extra_options: dict = dict(),
):
self.vector_column_name = vector_column_name
self.embedding_data = embedding_data
self.embedding_data_type = embedding_data_type
self.distance_type = distance_type
self.topn = topn
self.extra_options = extra_options
class MatchSparseExpr(ABC):
def __init__(
self,
vector_column_name: str,
sparse_data: SparseVector | dict,
distance_type: str,
topn: int,
opt_params: dict | None = None,
):
self.vector_column_name = vector_column_name
self.sparse_data = sparse_data
self.distance_type = distance_type
self.topn = topn
self.opt_params = opt_params
class MatchTensorExpr(ABC):
def __init__(
self,
column_name: str,
query_data: VEC,
query_data_type: str,
topn: int,
extra_option: dict | None = None,
):
self.column_name = column_name
self.query_data = query_data
self.query_data_type = query_data_type
self.topn = topn
self.extra_option = extra_option
class FusionExpr(ABC):
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
self.method = method
self.topn = topn
self.fusion_params = fusion_params
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
class OrderByExpr(ABC):
def __init__(self):
self.fields = list()
def asc(self, field: str):
self.fields.append((field, 0))
return self
def desc(self, field: str):
self.fields.append((field, 1))
return self
def fields(self):
return self.fields
class DocStoreConnection(ABC):
"""
Database operations
"""
@abstractmethod
def dbType(self) -> str:
"""
Return the type of the database.
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def health(self) -> dict:
"""
Return the health status of the database.
"""
raise NotImplementedError("Not implemented")
"""
Table operations
"""
@abstractmethod
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
"""
Create an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def deleteIdx(self, indexName: str, knowledgebaseId: str):
"""
Delete an index with given name
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
"""
Check if an index with given name exists
"""
raise NotImplementedError("Not implemented")
"""
CRUD operations
"""
@abstractmethod
def search(
self, selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
aggFields: list[str] = [],
rank_feature: dict | None = None
):
"""
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
"""
Get single chunk with given id
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
"""
Update or insert a bulk of rows
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
"""
Update rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
@abstractmethod
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
"""
Delete rows with given conjunctive equivalent filtering condition
"""
raise NotImplementedError("Not implemented")
"""
Helper functions for search result
"""
@abstractmethod
def getTotal(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getChunkIds(self, res):
raise NotImplementedError("Not implemented")
@abstractmethod
def getFields(self, res, fields: list[str]) -> dict[str, dict]:
raise NotImplementedError("Not implemented")
@abstractmethod
def getHighlight(self, res, keywords: list[str], fieldnm: str):
raise NotImplementedError("Not implemented")
@abstractmethod
def getAggregation(self, res, fieldnm: str):
raise NotImplementedError("Not implemented")
"""
SQL
"""
@abstractmethod
def sql(sql: str, fetch_size: int, format: str):
"""
Run the sql generated by text-to-sql
"""
raise NotImplementedError("Not implemented")

View File

@@ -0,0 +1,247 @@
import io
import hashlib
import zipfile
import requests
from requests.exceptions import Timeout, RequestException
from io import BytesIO
from typing import List, Union, Tuple, Optional, Dict
import PyPDF2
from docx import Document
import olefile
def _is_zip(h: bytes) -> bool:
return h.startswith(b"PK\x03\x04") or h.startswith(b"PK\x05\x06") or h.startswith(b"PK\x07\x08")
def _is_pdf(h: bytes) -> bool:
return h.startswith(b"%PDF-")
def _is_ole(h: bytes) -> bool:
return h.startswith(b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1")
def _sha10(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()[:10]
def _guess_ext(b: bytes) -> str:
h = b[:8]
if _is_zip(h):
try:
with zipfile.ZipFile(io.BytesIO(b), "r") as z:
names = [n.lower() for n in z.namelist()]
if any(n.startswith("word/") for n in names):
return ".docx"
if any(n.startswith("ppt/") for n in names):
return ".pptx"
if any(n.startswith("xl/") for n in names):
return ".xlsx"
except Exception:
pass
return ".zip"
if _is_pdf(h):
return ".pdf"
if _is_ole(h):
return ".doc"
return ".bin"
# Try to extract the real embedded payload from OLE's Ole10Native
def _extract_ole10native_payload(data: bytes) -> bytes:
try:
pos = 0
if len(data) < 4:
return data
_ = int.from_bytes(data[pos:pos+4], "little")
pos += 4
# filename/src/tmp (NUL-terminated ANSI)
for _ in range(3):
z = data.index(b"\x00", pos)
pos = z + 1
# skip unknown 4 bytes
pos += 4
if pos + 4 > len(data):
return data
size = int.from_bytes(data[pos:pos+4], "little")
pos += 4
if pos + size <= len(data):
return data[pos:pos+size]
except Exception:
pass
return data
def extract_embed_file(target: Union[bytes, bytearray]) -> List[Tuple[str, bytes]]:
"""
Only extract the 'first layer' of embedding, returning raw (filename, bytes).
"""
top = bytes(target)
head = top[:8]
out: List[Tuple[str, bytes]] = []
seen = set()
def push(b: bytes, name_hint: str = ""):
h10 = _sha10(b)
if h10 in seen:
return
seen.add(h10)
ext = _guess_ext(b)
# If name_hint has an extension use its basename; else fallback to guessed ext
if "." in name_hint:
fname = name_hint.split("/")[-1]
else:
fname = f"{h10}{ext}"
out.append((fname, b))
# OOXML/ZIP container (docx/xlsx/pptx)
if _is_zip(head):
try:
with zipfile.ZipFile(io.BytesIO(top), "r") as z:
embed_dirs = (
"word/embeddings/", "word/objects/", "word/activex/",
"xl/embeddings/", "ppt/embeddings/"
)
for name in z.namelist():
low = name.lower()
if any(low.startswith(d) for d in embed_dirs):
try:
b = z.read(name)
push(b, name)
except Exception:
pass
except Exception:
pass
return out
# OLE container (doc/ppt/xls)
if _is_ole(head):
try:
with olefile.OleFileIO(io.BytesIO(top)) as ole:
for entry in ole.listdir():
p = "/".join(entry)
try:
data = ole.openstream(entry).read()
except Exception:
continue
if not data:
continue
if "Ole10Native" in p or "ole10native" in p.lower():
data = _extract_ole10native_payload(data)
push(data, p)
except Exception:
pass
return out
return out
def extract_links_from_docx(docx_bytes: bytes):
"""
Extract all hyperlinks from a Word (.docx) document binary stream.
Args:
docx_bytes (bytes): Raw bytes of a .docx file.
Returns:
set[str]: A set of unique hyperlink URLs.
"""
links = set()
with BytesIO(docx_bytes) as bio:
document = Document(bio)
# Each relationship may represent a hyperlink, image, footer, etc.
for rel in document.part.rels.values():
if rel.reltype == (
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink"
):
links.add(rel.target_ref)
return links
def extract_links_from_pdf(pdf_bytes: bytes):
"""
Extract all clickable hyperlinks from a PDF binary stream.
Args:
pdf_bytes (bytes): Raw bytes of a PDF file.
Returns:
set[str]: A set of unique hyperlink URLs (unordered).
"""
links = set()
with BytesIO(pdf_bytes) as bio:
pdf = PyPDF2.PdfReader(bio)
for page in pdf.pages:
annots = page.get("/Annots")
if not annots or isinstance(annots, PyPDF2.generic.IndirectObject):
continue
for annot in annots:
obj = annot.get_object()
a = obj.get("/A")
if a and a.get("/URI"):
links.add(a["/URI"])
return links
_GLOBAL_SESSION: Optional[requests.Session] = None
def _get_session(headers: Optional[Dict[str, str]] = None) -> requests.Session:
"""Get or create a global reusable session."""
global _GLOBAL_SESSION
if _GLOBAL_SESSION is None:
_GLOBAL_SESSION = requests.Session()
_GLOBAL_SESSION.headers.update({
"User-Agent": (
"Mozilla/5.0 (X11; Linux x86_64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/121.0 Safari/537.36"
)
})
if headers:
_GLOBAL_SESSION.headers.update(headers)
return _GLOBAL_SESSION
def extract_html(
url: str,
timeout: float = 60.0,
headers: Optional[Dict[str, str]] = None,
max_retries: int = 2,
) -> Tuple[Optional[bytes], Dict[str, str]]:
"""
Extract the full HTML page as raw bytes from a given URL.
Automatically reuses a persistent HTTP session and applies robust timeout & retry logic.
Args:
url (str): Target webpage URL.
timeout (float): Request timeout in seconds (applies to connect + read).
headers (dict, optional): Extra HTTP headers.
max_retries (int): Number of retries on timeout or transient errors.
Returns:
tuple(bytes|None, dict):
- html_bytes: Raw HTML content (or None if failed)
- metadata: HTTP info (status_code, content_type, final_url, error if any)
"""
session = _get_session(headers=headers)
metadata = {"final_url": url, "status_code": "", "content_type": "", "error": ""}
for attempt in range(1, max_retries + 1):
try:
resp = session.get(url, timeout=timeout)
resp.raise_for_status()
html_bytes = resp.content
metadata.update({
"final_url": resp.url,
"status_code": str(resp.status_code),
"content_type": resp.headers.get("Content-Type", ""),
})
return html_bytes, metadata
except Timeout:
metadata["error"] = f"Timeout after {timeout}s (attempt {attempt}/{max_retries})"
if attempt >= max_retries:
continue
except RequestException as e:
metadata["error"] = f"Request failed: {e}"
continue
return None, metadata

View File

Some files were not shown because too many files have changed in this diff Show More